feat(arrow-select): concat kernel will merge dictionary values for list of dictionaries #6893

Merged Jan 4, 2025 (16 commits)
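
For context, a minimal sketch of the behavior this PR adds, using the same builder types as the new tests below. The `arrow_select::concat::concat` path is assumed from the file touched in this diff, and the final assertion is illustrative: after this change the child dictionaries are merged rather than naively concatenated, so shared values are stored once.

```rust
use arrow_array::builder::{GenericListBuilder, StringDictionaryBuilder};
use arrow_array::cast::AsArray;
use arrow_array::types::Int32Type;
use arrow_array::Array;
use arrow_select::concat::concat;

fn main() {
    // Two single-row List<Dictionary<Int32, Utf8>> arrays sharing the value "a".
    let make = |v: &str| {
        let mut builder =
            GenericListBuilder::<i32, _>::new(StringDictionaryBuilder::<Int32Type>::new());
        builder.append_option(Some(vec![Some(v)]));
        builder.finish()
    };
    let (left, right) = (make("a"), make("a"));

    let merged = concat(&[&left, &right]).unwrap();
    let dict = merged.as_list::<i32>().values().as_dictionary::<Int32Type>();

    // The merged child dictionary stores "a" once, not once per input.
    assert_eq!(dict.values().len(), 1);
}
```

The `concat_many_dictionary_list_arrays` test below exercises the same property at scale (80,000 rows, 8 unique values).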
52 changes: 52 additions & 0 deletions arrow-buffer/src/buffer/offset.rs
@@ -133,6 +133,38 @@ impl<O: ArrowNativeType> OffsetBuffer<O> {
Self(out.into())
}

/// Returns an iterator over the lengths of this [`OffsetBuffer`]
///
/// ```
/// # use arrow_buffer::{OffsetBuffer, ScalarBuffer};
/// let offsets = OffsetBuffer::<_>::new(ScalarBuffer::<i32>::from(vec![0, 1, 4, 9]));
/// assert_eq!(offsets.lengths().collect::<Vec<usize>>(), vec![1, 3, 5]);
/// ```
///
/// An empty [`OffsetBuffer`] returns an empty iterator
/// ```
/// # use arrow_buffer::OffsetBuffer;
/// let offsets = OffsetBuffer::<i32>::new_empty();
/// assert_eq!(offsets.lengths().count(), 0);
/// ```
///
/// This can be used to merge multiple [`OffsetBuffer`]s into one
/// ```
/// # use arrow_buffer::{OffsetBuffer, ScalarBuffer};
///
/// let buffer1 = OffsetBuffer::<i32>::from_lengths([2, 6, 3, 7, 2]);
/// let buffer2 = OffsetBuffer::<i32>::from_lengths([1, 3, 5, 7, 9]);
///
/// let merged = OffsetBuffer::<i32>::from_lengths(
/// vec![buffer1, buffer2].iter().flat_map(|x| x.lengths())
/// );
///
/// assert_eq!(merged.lengths().collect::<Vec<_>>(), &[2, 6, 3, 7, 2, 1, 3, 5, 7, 9]);
/// ```
pub fn lengths(&self) -> impl ExactSizeIterator<Item = usize> + '_ {
self.0.windows(2).map(|x| x[1].as_usize() - x[0].as_usize())
}

/// Free up unused memory.
pub fn shrink_to_fit(&mut self) {
self.0.shrink_to_fit();
@@ -244,4 +276,24 @@ mod tests {
fn from_lengths_usize_overflow() {
OffsetBuffer::<i32>::from_lengths([usize::MAX, 1]);
}

#[test]
fn get_lengths() {
let offsets = OffsetBuffer::<i32>::new(ScalarBuffer::<i32>::from(vec![0, 1, 4, 9]));
assert_eq!(offsets.lengths().collect::<Vec<usize>>(), vec![1, 3, 5]);
}

#[test]
fn get_lengths_should_be_with_fixed_size() {
let offsets = OffsetBuffer::<i32>::new(ScalarBuffer::<i32>::from(vec![0, 1, 4, 9]));
let iter = offsets.lengths();
assert_eq!(iter.size_hint(), (3, Some(3)));
assert_eq!(iter.len(), 3);
}

#[test]
fn get_lengths_from_empty_offset_buffer_should_be_empty_iterator() {
let offsets = OffsetBuffer::<i32>::new_empty();
assert_eq!(offsets.lengths().collect::<Vec<usize>>(), vec![]);
}
}
191 changes: 180 additions & 11 deletions arrow-select/src/concat.rs
@@ -34,9 +34,9 @@ use crate::dictionary::{merge_dictionary_values, should_merge_dictionary_values};
use arrow_array::cast::AsArray;
use arrow_array::types::*;
use arrow_array::*;
use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder, NullBuffer};
use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder, NullBuffer, OffsetBuffer};
use arrow_data::transform::{Capacities, MutableArrayData};
use arrow_schema::{ArrowError, DataType, SchemaRef};
use arrow_schema::{ArrowError, DataType, FieldRef, SchemaRef};
use std::sync::Arc;

fn binary_capacity<T: ByteArrayType>(arrays: &[&dyn Array]) -> Capacities {
@@ -129,6 +129,54 @@ fn concat_dictionaries<K: ArrowDictionaryKeyType>(
Ok(Arc::new(array))
}

fn concat_lists<OffsetSize: OffsetSizeTrait>(
arrays: &[&dyn Array],
field: &FieldRef,
) -> Result<ArrayRef, ArrowError> {
let mut output_len = 0;
let mut list_has_nulls = false;

let lists = arrays
.iter()
.map(|x| x.as_list::<OffsetSize>())
.inspect(|l| {
output_len += l.len();
list_has_nulls |= l.null_count() != 0;
})
.collect::<Vec<_>>();

let lists_nulls = list_has_nulls.then(|| {
let mut nulls = BooleanBufferBuilder::new(output_len);
for l in &lists {
match l.nulls() {
Some(n) => nulls.append_buffer(n.inner()),
None => nulls.append_n(l.len(), true),
}
}
NullBuffer::new(nulls.finish())
});

let values: Vec<&dyn Array> = lists.iter().map(|x| x.values().as_ref()).collect();

let concatenated_values = concat(values.as_slice())?;

// Merge value offsets from the lists
let value_offset_buffer =
OffsetBuffer::<OffsetSize>::from_lengths(lists.iter().flat_map(|x| x.offsets().lengths()));

let array = GenericListArray::<OffsetSize>::try_new(
Arc::clone(field),
value_offset_buffer,
concatenated_values,
lists_nulls,
)?;

Ok(Arc::new(array))
}
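
The lengths-based rebuild above matters because each input's offsets are absolute positions into its own child array, so appending raw offsets would produce a non-monotonic buffer. A standalone illustration of the idea, using only the `OffsetBuffer` API added in the first file of this diff:

```rust
use arrow_buffer::OffsetBuffer;

fn main() {
    let a = OffsetBuffer::<i32>::from_lengths([2, 3]); // offsets [0, 2, 5]
    let b = OffsetBuffer::<i32>::from_lengths([1, 4]); // offsets [0, 1, 5]

    // Naively appending b's raw offsets after a's would give [0, 2, 5, 0, 1, 5],
    // which is not monotonically increasing. Rebuilding from lengths re-bases
    // b's entries on top of a's, matching the concatenated child values.
    let merged = OffsetBuffer::<i32>::from_lengths(a.lengths().chain(b.lengths()));
    assert_eq!(merged.lengths().collect::<Vec<_>>(), vec![2, 3, 1, 4]);
}
```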

macro_rules! dict_helper {
($t:ty, $arrays:expr) => {
return Ok(Arc::new(concat_dictionaries::<$t>($arrays)?) as _)
Expand Down Expand Up @@ -163,14 +211,20 @@ pub fn concat(arrays: &[&dyn Array]) -> Result<ArrayRef, ArrowError> {
"It is not possible to concatenate arrays of different data types.".to_string(),
));
}
if let DataType::Dictionary(k, _) = d {
downcast_integer! {
k.as_ref() => (dict_helper, arrays),
_ => unreachable!("illegal dictionary key type {k}")
};
} else {
let capacity = get_capacity(arrays, d);
concat_fallback(arrays, capacity)

match d {
DataType::Dictionary(k, _) => {
downcast_integer! {
k.as_ref() => (dict_helper, arrays),
_ => unreachable!("illegal dictionary key type {k}")
}
}
DataType::List(field) => concat_lists::<i32>(arrays, field),
DataType::LargeList(field) => concat_lists::<i64>(arrays, field),
_ => {
let capacity = get_capacity(arrays, d);
concat_fallback(arrays, capacity)
}
}
}

@@ -228,8 +282,9 @@ pub fn concat_batches<'a>(
#[cfg(test)]
mod tests {
use super::*;
use arrow_array::builder::StringDictionaryBuilder;
use arrow_array::builder::{GenericListBuilder, StringDictionaryBuilder};
use arrow_schema::{Field, Schema};
use std::fmt::Debug;

#[test]
fn test_concat_empty_vec() {
@@ -851,4 +906,118 @@ mod tests {
assert_eq!(array.null_count(), 10);
assert_eq!(array.logical_null_count(), 10);
}

#[test]
fn concat_dictionary_list_array_simple() {
let scalars = vec![
create_single_row_list_of_dict(vec![Some("a")]),
create_single_row_list_of_dict(vec![Some("a")]),
create_single_row_list_of_dict(vec![Some("b")]),
];

let arrays = scalars
.iter()
.map(|a| a as &(dyn Array))
.collect::<Vec<_>>();
let concat_res = concat(arrays.as_slice()).unwrap();

let expected_list = create_list_of_dict(vec![
// Rows 1..=3
Some(vec![Some("a")]),
Some(vec![Some("a")]),
Some(vec![Some("b")]),
]);

let list = concat_res.as_list::<i32>();

// Assert that the list is equal to the expected list
list.iter().zip(expected_list.iter()).for_each(|(a, b)| {
assert_eq!(a, b);
});

assert_dictionary_has_unique_values::<_, StringArray>(
list.values().as_dictionary::<Int32Type>(),
);
}

#[test]
fn concat_many_dictionary_list_arrays() {
let number_of_unique_values = 8;
let scalars = (0..80000)
.map(|i| {
create_single_row_list_of_dict(vec![Some(
(i % number_of_unique_values).to_string(),
)])
})
.collect::<Vec<_>>();

let arrays = scalars
.iter()
.map(|a| a as &(dyn Array))
.collect::<Vec<_>>();
let concat_res = concat(arrays.as_slice()).unwrap();

let expected_list = create_list_of_dict(
(0..80000)
.map(|i| Some(vec![Some((i % number_of_unique_values).to_string())]))
.collect::<Vec<_>>(),
);

let list = concat_res.as_list::<i32>();

// Assert that the list is equal to the expected list
list.iter().zip(expected_list.iter()).for_each(|(a, b)| {
assert_eq!(a, b);
});

assert_dictionary_has_unique_values::<_, StringArray>(
list.values().as_dictionary::<Int32Type>(),
);
}

fn create_single_row_list_of_dict(
list_items: Vec<Option<impl AsRef<str>>>,
) -> GenericListArray<i32> {
let rows = list_items.into_iter().map(Some).collect();

create_list_of_dict(vec![rows])
}

fn create_list_of_dict(
rows: Vec<Option<Vec<Option<impl AsRef<str>>>>>,
) -> GenericListArray<i32> {
let mut builder =
GenericListBuilder::<i32, _>::new(StringDictionaryBuilder::<Int32Type>::new());

for row in rows {
builder.append_option(row);
}

builder.finish()
}

fn assert_dictionary_has_unique_values<'a, K, V>(array: &'a DictionaryArray<K>)
where
K: ArrowDictionaryKeyType,
V: Sync + Send + 'static,
&'a V: ArrayAccessor + IntoIterator,
<&'a V as ArrayAccessor>::Item: Default + Clone + PartialEq + Debug + Ord,
<&'a V as IntoIterator>::Item: Clone + PartialEq + Debug + Ord,
{
let dict = array.downcast_dict::<V>().unwrap();
let mut values = dict.values().into_iter().collect::<Vec<_>>();

// dedup only removes adjacent duplicates, so sort first to make any duplicates adjacent
values.sort();

let mut unique_values = values.clone();

unique_values.dedup();

assert_eq!(
values, unique_values,
"There are duplicates in the value list (the value list here is sorted which is only for the assertion)"
);
}
}