Skip to content

Commit

Permalink
Implement filter kernel for byte view arrays. (#5624)
Browse files Browse the repository at this point in the history
* Implement `filter` kernel for byte view arrays.

* Add unit tests and fix.

* Deprecate `ArrowPrimitiveType::get_byte_width`.

* Add string view filter benchmark.
  • Loading branch information
RinChanNOWWW authored Apr 15, 2024
1 parent fee6921 commit e88e5aa
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 10 deletions.
3 changes: 2 additions & 1 deletion arrow-arith/src/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use arrow_array::builder::BufferBuilder;
use arrow_array::types::ArrowDictionaryKeyType;
use arrow_array::*;
use arrow_buffer::buffer::NullBuffer;
use arrow_buffer::ArrowNativeType;
use arrow_buffer::{Buffer, MutableBuffer};
use arrow_data::ArrayData;
use arrow_schema::ArrowError;
Expand Down Expand Up @@ -386,7 +387,7 @@ where
O: ArrowPrimitiveType,
F: Fn(A::Item, B::Item) -> Result<O::Native, ArrowError>,
{
let mut buffer = MutableBuffer::new(len * O::get_byte_width());
let mut buffer = MutableBuffer::new(len * O::Native::get_byte_width());
for idx in 0..len {
unsafe {
buffer.push_unchecked(op(a.value_unchecked(idx), b.value_unchecked(idx))?);
Expand Down
1 change: 1 addition & 0 deletions arrow-array/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ pub trait ArrowPrimitiveType: primitive::PrimitiveTypeSealed + 'static {
const DATA_TYPE: DataType;

/// Returns the byte width of this primitive type.
///
/// The default implementation returns the size in bytes of `Self::Native`,
/// as reported by `std::mem::size_of`.
///
/// Deprecated: the attribute below directs callers to
/// `ArrowNativeType::get_byte_width` instead.
#[deprecated(note = "Use ArrowNativeType::get_byte_width")]
fn get_byte_width() -> usize {
std::mem::size_of::<Self::Native>()
}
Expand Down
5 changes: 5 additions & 0 deletions arrow-buffer/src/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ mod private {
pub trait ArrowNativeType:
std::fmt::Debug + Send + Sync + Copy + PartialOrd + Default + private::Sealed + 'static
{
/// Returns the byte width of this native type.
///
/// The default implementation returns `std::mem::size_of::<Self>()`, i.e. the
/// in-memory size of a single value of the implementing type.
fn get_byte_width() -> usize {
std::mem::size_of::<Self>()
}

/// Convert native integer type from usize
///
/// Returns `None` if [`Self`] is not an integer or conversion would result
Expand Down
5 changes: 5 additions & 0 deletions arrow-data/src/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1770,6 +1770,11 @@ impl ArrayDataBuilder {
self
}

/// Appends every buffer yielded by `bs` to this builder's list of buffers,
/// keeping the order in which they are produced, and returns the builder so
/// calls can be chained.
///
/// Any `IntoIterator<Item = Buffer>` is accepted (a `Vec<Buffer>`, an array,
/// or an iterator adaptor), so callers that already hold an iterator do not
/// need to collect it into a temporary `Vec` first. Existing callers that
/// pass a `Vec<Buffer>` keep working unchanged.
pub fn add_buffers(mut self, bs: impl IntoIterator<Item = Buffer>) -> Self {
    self.buffers.extend(bs);
    self
}

pub fn child_data(mut self, v: Vec<ArrayData>) -> Self {
self.child_data = v;
self
Expand Down
112 changes: 103 additions & 9 deletions arrow-select/src/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ use std::sync::Arc;
use arrow_array::builder::BooleanBufferBuilder;
use arrow_array::cast::AsArray;
use arrow_array::types::{
ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, RunEndIndexType,
ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, ByteViewType, RunEndIndexType,
};
use arrow_array::*;
use arrow_buffer::{bit_util, BooleanBuffer, NullBuffer, RunEndBuffer};
use arrow_buffer::{bit_util, ArrowNativeType, BooleanBuffer, NullBuffer, RunEndBuffer};
use arrow_buffer::{Buffer, MutableBuffer};
use arrow_data::bit_iterator::{BitIndexIterator, BitSliceIterator};
use arrow_data::transform::MutableArrayData;
Expand Down Expand Up @@ -333,12 +333,18 @@ fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<Array
DataType::LargeUtf8 => {
Ok(Arc::new(filter_bytes(values.as_string::<i64>(), predicate)))
}
DataType::Utf8View => {
Ok(Arc::new(filter_byte_view(values.as_string_view(), predicate)))
}
DataType::Binary => {
Ok(Arc::new(filter_bytes(values.as_binary::<i32>(), predicate)))
}
DataType::LargeBinary => {
Ok(Arc::new(filter_bytes(values.as_binary::<i64>(), predicate)))
}
DataType::BinaryView => {
Ok(Arc::new(filter_byte_view(values.as_binary_view(), predicate)))
}
DataType::RunEndEncoded(_, _) => {
downcast_run_array!{
values => Ok(Arc::new(filter_run_end_array(values, predicate)?)),
Expand Down Expand Up @@ -508,12 +514,8 @@ fn filter_boolean(array: &BooleanArray, predicate: &FilterPredicate) -> BooleanA
BooleanArray::from(data)
}

/// `filter` implementation for primitive arrays
fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T>
where
T: ArrowPrimitiveType,
{
let values = array.values();
#[inline(never)]
fn filter_native<T: ArrowNativeType>(values: &[T], predicate: &FilterPredicate) -> Buffer {
assert!(values.len() >= predicate.filter.len());

let buffer = match &predicate.strategy {
Expand Down Expand Up @@ -546,9 +548,19 @@ where
IterationStrategy::All | IterationStrategy::None => unreachable!(),
};

buffer.into()
}

/// `filter` implementation for primitive arrays
fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T>
where
T: ArrowPrimitiveType,
{
let values = array.values();
let buffer = filter_native(values, predicate);
let mut builder = ArrayDataBuilder::new(array.data_type().clone())
.len(predicate.count)
.add_buffer(buffer.into());
.add_buffer(buffer);

if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
Expand Down Expand Up @@ -673,6 +685,25 @@ where
GenericByteArray::from(data)
}

/// `filter` implementation for byte view arrays (string view / binary view).
///
/// The array's views are passed through `filter_native` together with
/// `predicate`, and the resulting buffer becomes the first buffer of the
/// output. All of the source array's data buffers are then forwarded to the
/// output as-is, in their original order. The validity mask, when
/// `filter_null_mask` yields one, is attached along with its null count.
fn filter_byte_view<T: ByteViewType>(
    array: &GenericByteViewArray<T>,
    predicate: &FilterPredicate,
) -> GenericByteViewArray<T> {
    let filtered_views = filter_native(array.views(), predicate);

    let base = ArrayDataBuilder::new(T::DATA_TYPE)
        .len(predicate.count)
        .add_buffer(filtered_views)
        .add_buffers(array.data_buffers().to_vec());

    // Only attach null information when the filtered array actually has a mask.
    let builder = match filter_null_mask(array.nulls(), predicate) {
        Some((null_count, nulls)) => base.null_count(null_count).null_bit_buffer(Some(nulls)),
        None => base,
    };

    // NOTE(review): `build_unchecked` skips validation. Soundness presumably
    // relies on `filter_native` emitting only views taken from `array` and on
    // the data buffers being forwarded in the same order, so every view still
    // refers to valid data -- confirm against `filter_native`.
    let data = unsafe { builder.build_unchecked() };
    GenericByteViewArray::from(data)
}

/// `filter` implementation for dictionaries
fn filter_dict<T>(array: &DictionaryArray<T>, predicate: &FilterPredicate) -> DictionaryArray<T>
where
Expand Down Expand Up @@ -888,6 +919,69 @@ mod tests {
assert!(d.is_null(1));
}

/// Shared body for the byte view filter tests.
///
/// Builds a five element view array containing short values, a null and one
/// value longer than 12 bytes, filters it with two different masks and checks
/// both the length and the contents of each result.
fn _test_filter_byte_view<T>()
where
    T: ByteViewType,
    str: AsRef<T::Native>,
    T::Native: PartialEq,
{
    // Builds a view array from optional string values (`None` becomes a null).
    let make_array = |items: &[Option<&str>]| {
        let mut builder = GenericByteViewBuilder::<T>::new();
        for item in items.iter().copied() {
            match item {
                Some(value) => builder.append_value(value),
                None => builder.append_null(),
            }
        }
        builder.finish()
    };

    let array = make_array(&[
        Some("hello"),
        Some("world"),
        None,
        Some("large payload over 12 bytes"),
        Some("lulu"),
    ]);

    // Keeps a short value, the null and the long value.
    {
        let predicate = BooleanArray::from(vec![true, false, true, true, false]);
        let actual = filter(&array, &predicate).unwrap();

        assert_eq!(actual.len(), 3);

        let expected = make_array(&[
            Some("hello"),
            None,
            Some("large payload over 12 bytes"),
        ]);

        assert_eq!(actual.as_ref(), &expected);
    }

    // Keeps only the first and last (both short) values.
    {
        let predicate = BooleanArray::from(vec![true, false, false, false, true]);
        let actual = filter(&array, &predicate).unwrap();

        assert_eq!(actual.len(), 2);

        let expected = make_array(&[Some("hello"), Some("lulu")]);

        assert_eq!(actual.as_ref(), &expected);
    }
}

#[test]
fn test_filter_string_view() {
// Runs the shared byte view filter checks with `StringViewType`.
_test_filter_byte_view::<StringViewType>()
}

#[test]
fn test_filter_binary_view() {
// Runs the shared byte view filter checks with `BinaryViewType`.
_test_filter_byte_view::<BinaryViewType>()
}

#[test]
fn test_filter_array_slice_with_null() {
let a = Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]).slice(1, 4);
Expand Down
26 changes: 26 additions & 0 deletions arrow/benches/filter_kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,32 @@ fn add_benchmark(c: &mut Criterion) {
c.bench_function("filter single record batch", |b| {
b.iter(|| filter_record_batch(&batch, &filter_array))
});

let data_array = create_string_view_array_with_len(size, 0.5, 4, false);
c.bench_function("filter context short string view (kept 1/2)", |b| {
b.iter(|| bench_built_filter(&filter, &data_array))
});
c.bench_function(
"filter context short string view high selectivity (kept 1023/1024)",
|b| b.iter(|| bench_built_filter(&dense_filter, &data_array)),
);
c.bench_function(
"filter context short string view low selectivity (kept 1/1024)",
|b| b.iter(|| bench_built_filter(&sparse_filter, &data_array)),
);

let data_array = create_string_view_array_with_len(size, 0.5, 4, true);
c.bench_function("filter context mixed string view (kept 1/2)", |b| {
b.iter(|| bench_built_filter(&filter, &data_array))
});
c.bench_function(
"filter context mixed string view high selectivity (kept 1023/1024)",
|b| b.iter(|| bench_built_filter(&dense_filter, &data_array)),
);
c.bench_function(
"filter context mixed string view low selectivity (kept 1/1024)",
|b| b.iter(|| bench_built_filter(&sparse_filter, &data_array)),
);
}

criterion_group!(benches, add_benchmark);
Expand Down

0 comments on commit e88e5aa

Please sign in to comment.