Skip to content

Commit

Permalink
Add TILEDB_DATETIME_DAY type support for Arrow (#2002)
Browse files Browse the repository at this point in the history
* Add in-place buffer shift for TILEDB_DATETIME_DAY
* Add tests
  • Loading branch information
kounelisagis authored Oct 23, 2024
1 parent 61b1ce2 commit f206545
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 7 deletions.
32 changes: 26 additions & 6 deletions tiledb/py_arrow_io_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,8 @@ ArrowInfo tiledb_buffer_arrow_fmt(BufferInfo bufferinfo, bool use_list = true) {
return ArrowInfo("tsu:");
case TILEDB_DATETIME_NS:
return ArrowInfo("tsn:");
case TILEDB_DATETIME_DAY:
return ArrowInfo("tdD");
// TILEDB_BOOL is stored as a uint8_t but arrow::Type::BOOL is 1 bit
case TILEDB_BOOL:
return ArrowInfo("C");
Expand All @@ -242,7 +244,6 @@ ArrowInfo tiledb_buffer_arrow_fmt(BufferInfo bufferinfo, bool use_list = true) {
case TILEDB_DATETIME_YEAR:
case TILEDB_DATETIME_MONTH:
case TILEDB_DATETIME_WEEK:
case TILEDB_DATETIME_DAY:
case TILEDB_DATETIME_HR:
case TILEDB_DATETIME_MIN:
case TILEDB_DATETIME_PS:
Expand Down Expand Up @@ -739,6 +740,14 @@ int64_t flags_for_buffer(BufferInfo binfo) {
return 0;
}

// Converts an unsigned 64-bit value to T after verifying that it does not
// exceed std::numeric_limits<T>::max().
//
// Returns the value converted with static_cast<T>.
// Throws tiledb::TileDBError when the value is larger than T's maximum.
template <typename T> T cast_checked(uint64_t val) {
  // Anything up to and including T's maximum converts without loss.
  if (val <= std::numeric_limits<T>::max()) {
    return static_cast<T>(val);
  }
  throw tiledb::TileDBError(
      "[TileDB-Arrow] Value too large to cast to requested type");
}

void ArrowExporter::export_(const std::string &name, ArrowArray *array,
ArrowSchema *schema, ArrowAdapter::release_cb cb,
void *cb_data) {
Expand All @@ -762,13 +771,11 @@ void ArrowExporter::export_(const std::string &name, ArrowArray *array,
if (bufferinfo.is_var) {
buffers = {nullptr, bufferinfo.offsets, bufferinfo.data};
} else {
cpp_schema = new CPPArrowSchema(name, arrow_fmt.fmt_, std::nullopt,
arrow_flags, {}, {});
buffers = {nullptr, bufferinfo.data};
}
cpp_schema->export_ptr(schema);

size_t elem_num = 0;
size_t elem_num = bufferinfo.data_num;
if (bufferinfo.is_var) {
// adjust for arrow offset unless empty result
elem_num = (bufferinfo.offsets_num == 0) ? 0 : bufferinfo.offsets_num - 1;
Expand All @@ -778,8 +785,21 @@ void ArrowExporter::export_(const std::string &name, ArrowArray *array,
// take the size of the entire buffer and divide by the size of each
// element
elem_num = bufferinfo.data_num / bufferinfo.tdbtype.cell_val_num;
} else {
elem_num = bufferinfo.data_num;
} else if (arrow_fmt.fmt_ == "tdD") {
// for Arrow date32 we only need the first 4 bytes of each 8-byte
// TILEDB_DATETIME_DAY element which we keep by in-place left shifting
for (size_t i = 0; i < bufferinfo.data_num; i++) {
uint32_t lost_data = *(reinterpret_cast<uint32_t *>(
static_cast<uint8_t *>(buffers[1]) + i * 8 + 4));
if (lost_data != 0) {
throw tiledb::TileDBError(
"[TileDB-Arrow] Non-zero data detected in the memory buffer at "
"position that will be overwritten");
}

static_cast<uint32_t *>(buffers[1])[i] =
cast_checked<uint32_t>(static_cast<uint64_t *>(buffers[1])[i]);
}
}
}

Expand Down
135 changes: 134 additions & 1 deletion tiledb/tests/test_pandas_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import string
import sys
import uuid
from collections import OrderedDict

import numpy as np
import pyarrow
Expand All @@ -16,6 +17,7 @@

from .common import (
DiskTestCase,
assert_dict_arrays_equal,
dtype_max,
dtype_min,
has_pandas,
Expand Down Expand Up @@ -219,7 +221,6 @@ def test_object_dtype(self):
"<M8[Y]",
"<M8[M]",
"<M8[W]",
"<M8[D]",
"<M8[h]",
"<M8[m]",
"<M8[s]",
Expand Down Expand Up @@ -1665,3 +1666,135 @@ def gen_array(sz):

with tiledb.open(uri) as A:
tm.assert_frame_equal(df.sort_index(), A.df[:])


def test_datetime64_days_dtype_read_sc25572(checked_path):
    """Write a datetime64[D] attribute through the array API, then read it back.

    The round trip is checked twice: once against the raw dict returned by
    ``array[:]`` and once against the dataframe returned by ``array.df[:]``.
    """
    uri = checked_path.path()

    dim = tiledb.Dim(name="d1", domain=(0, 99), tile=100, dtype=np.int64)
    domain = tiledb.Domain(dim)
    attrs = [tiledb.Attr("Attr1", dtype="datetime64[D]", var=False, nullable=False)]
    tiledb.Array.create(uri, tiledb.ArraySchema(domain, attrs))

    # One random day value for each of the 100 cells in the dimension domain.
    days = [np.datetime64(np.random.randint(0, 10000), "D") for _ in range(100)]
    data = OrderedDict([("Attr1", np.array(days, dtype="datetime64[D]"))])

    # Reference frame is built as datetime64[ns]; dtype equality is not
    # enforced in the comparison below (check_dtype=False).
    original_df = pd.DataFrame(data, dtype="datetime64[ns]")
    original_df.index.name = "d1"

    with tiledb.open(uri, "w") as array:
        array[:] = data

    with tiledb.open(uri, "r") as array:
        assert_dict_arrays_equal(array[:], data)
        df_received = array.df[:].set_index("d1")
        tm.assert_frame_equal(
            original_df, df_received, check_datetimelike_compat=True, check_dtype=False
        )


def test_datetime64_days_dtype_write_sc25572(checked_path):
    """Ingest a dataframe of day-aligned timestamps via from_pandas and read it back.

    The values are generated with day resolution but held in a
    ``datetime64[ns]`` array before being handed to ``tiledb.from_pandas``.
    """
    uri = checked_path.path()

    days = [np.datetime64(np.random.randint(0, 10000), "D") for _ in range(100)]
    data = OrderedDict([("Attr1", np.array(days, dtype="datetime64[ns]"))])

    original_df = pd.DataFrame(data)
    tiledb.from_pandas(uri, original_df)

    with tiledb.open(uri, "r") as array:
        # Compare both the raw dict read and the dataframe read.
        assert_dict_arrays_equal(array[:], data)
        df_received = array.df[:]
        tm.assert_frame_equal(
            original_df, df_received, check_datetimelike_compat=True, check_dtype=False
        )


def test_datetime64_days_dtype_read_out_of_range_sc25572(checked_path):
    """Reading a datetime64[D] value of 10,000,000 days as a dataframe raises ValueError.

    The write itself is expected to succeed; only ``array.df[:]`` must fail,
    with a message reporting that year 29349 is out of range.
    """
    uri = checked_path.path()

    dim = tiledb.Dim(name="d1", domain=(0, 9), tile=10, dtype=np.int64)
    domain = tiledb.Domain(dim)
    attrs = [tiledb.Attr("Attr1", dtype="datetime64[D]", var=False, nullable=False)]
    tiledb.Array.create(uri, tiledb.ArraySchema(domain, attrs))

    # The same far-future day value in every one of the 10 cells.
    far_future = np.datetime64(10000000, "D")
    data = OrderedDict(
        [("Attr1", np.array([far_future] * 10, dtype="datetime64[D]"))]
    )

    with tiledb.open(uri, "w") as array:
        array[:] = data

    with tiledb.open(uri, "r") as array:
        with pytest.raises(ValueError) as excinfo:
            print(array.df[:])
        # Checked after the raises block so the assertion actually executes.
        assert "year 29349 is out of range" in str(excinfo.value)


def test_datetime64_days_dtype_read_overflow_sc25572(checked_path):
    """Reading a datetime64[D] value that does not fit in 32 bits raises TileDBError.

    10,000,000,000 days exceeds the 32-bit range, so the dataframe read is
    expected to fail with the "[TileDB-Arrow] Non-zero data detected ..."
    error instead of returning truncated values.
    """
    uri = checked_path.path()

    dim = tiledb.Dim(name="d1", domain=(0, 9), tile=10, dtype=np.int64)
    domain = tiledb.Domain(dim)
    attrs = [tiledb.Attr("Attr1", dtype="datetime64[D]", var=False, nullable=False)]
    tiledb.Array.create(uri, tiledb.ArraySchema(domain, attrs))

    # The same oversized day value in every one of the 10 cells.
    oversized = np.datetime64(10000000000, "D")
    data = OrderedDict(
        [("Attr1", np.array([oversized] * 10, dtype="datetime64[D]"))]
    )

    with tiledb.open(uri, "w") as array:
        array[:] = data

    expected_message = "[TileDB-Arrow] Non-zero data detected in the memory buffer at position that will be overwritten"

    with tiledb.open(uri, "r") as array:
        with pytest.raises(tiledb.TileDBError) as excinfo:
            print(array.df[:])
        # Checked after the raises block so the assertion actually executes.
        assert expected_message in str(excinfo.value)

0 comments on commit f206545

Please sign in to comment.