Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(arrow-select): concat kernel will merge dictionary values for list of dictionaries #6893

Merged
merged 16 commits into from
Jan 4, 2025
336 changes: 334 additions & 2 deletions arrow-select/src/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,11 @@ use crate::dictionary::{merge_dictionary_values, should_merge_dictionary_values}
use arrow_array::cast::AsArray;
use arrow_array::types::*;
use arrow_array::*;
use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder, NullBuffer};
use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder, Buffer, NullBuffer, OffsetBuffer};
use arrow_data::transform::{Capacities, MutableArrayData};
use arrow_data::ArrayDataBuilder;
use arrow_schema::{ArrowError, DataType, SchemaRef};
use num::Saturating;
use std::sync::Arc;

fn binary_capacity<T: ByteArrayType>(arrays: &[&dyn Array]) -> Capacities {
Expand Down Expand Up @@ -129,12 +131,161 @@ fn concat_dictionaries<K: ArrowDictionaryKeyType>(
Ok(Arc::new(array))
}

fn concat_list_of_dictionaries<OffsetSize: OffsetSizeTrait, K: ArrowDictionaryKeyType>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would significantly reduce the codegen to instead split concatenating the dictionary values, from re-encoding the offsets. As an added bonus this could also be done recursively.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated, please let me know it that was what you meant, I think it will be slower this way, no?

arrays: &[&dyn Array],
) -> Result<ArrayRef, ArrowError> {
let mut output_len = 0;
let mut list_has_nulls = false;

let lists = arrays
.iter()
.map(|x| x.as_list::<OffsetSize>())
.inspect(|l| {
output_len += l.len();
list_has_nulls |= l.null_count() != 0;
})
.collect::<Vec<_>>();

let mut dictionary_output_len = 0;
let dictionaries: Vec<_> = lists
.iter()
.map(|x| x.values().as_ref().as_dictionary::<K>())
.inspect(|d| dictionary_output_len += d.len())
.collect();

if !should_merge_dictionary_values::<K>(&dictionaries, dictionary_output_len) {
return concat_fallback(arrays, Capacities::Array(output_len));
}

let merged = merge_dictionary_values(&dictionaries, None)?;

let lists_nulls = list_has_nulls.then(|| {
let mut nulls = BooleanBufferBuilder::new(output_len);
for l in &lists {
match l.nulls() {
Some(n) => nulls.append_buffer(n.inner()),
None => nulls.append_n(l.len(), true),
}
}
NullBuffer::new(nulls.finish())
});

// Recompute keys
let mut key_values = Vec::with_capacity(dictionary_output_len);

let mut dictionary_has_nulls = false;
for (d, mapping) in dictionaries.iter().zip(merged.key_mappings) {
dictionary_has_nulls |= d.null_count() != 0;
for key in d.keys().values() {
// Use get to safely handle nulls
key_values.push(mapping.get(key.as_usize()).copied().unwrap_or_default())
}
}

let dictionary_nulls = dictionary_has_nulls.then(|| {
let mut nulls = BooleanBufferBuilder::new(dictionary_output_len);
for d in &dictionaries {
match d.nulls() {
Some(n) => nulls.append_buffer(n.inner()),
None => nulls.append_n(d.len(), true),
}
}
NullBuffer::new(nulls.finish())
});

let keys = PrimitiveArray::<K>::new(key_values.into(), dictionary_nulls);
// Sanity check
assert_eq!(keys.len(), dictionary_output_len);

let array = unsafe { DictionaryArray::new_unchecked(keys, merged.values) };

// Merge value offsets from the lists
let all_value_offsets_iterator = lists
.iter()
.map(|x| x.offsets());

let value_offset_buffer = merge_value_offsets(all_value_offsets_iterator);

let builder = ArrayDataBuilder::new(arrays[0].data_type().clone())
.len(output_len)
.nulls(lists_nulls)
// `GenericListArray` must only have 1 buffer
.buffers(vec![value_offset_buffer])
// `GenericListArray` must only have 1 child_data
.child_data(vec![array.to_data()]);

// TODO - maybe use build_unchecked?
let array_data = builder.build()?;
rluvaton marked this conversation as resolved.
Show resolved Hide resolved

let array = GenericListArray::<OffsetSize>::from(array_data);
Ok(Arc::new(array))
}

/// Merge value offsets
///
///
/// if we have the following
/// [[0, 3, 5], [0, 2, 2, 8], [], [0, 0, 1]]
/// The output should be
/// [ 0, 3, 5, 7, 7, 13, 13, 14]
fn merge_value_offsets<'a, OffsetSize: OffsetSizeTrait, I: Iterator<Item = &'a OffsetBuffer<OffsetSize>>>(offset_buffers_iterator: I) -> Buffer {
// 1. Filter out empty lists
let mut offset_buffers_iterator = offset_buffers_iterator.filter(|x| !x.is_empty());

// 2. Get first non-empty list as the starting point
let starting_buffer = offset_buffers_iterator.next();

// 3. If we have only empty lists, return an empty buffer
if starting_buffer.is_none() {
return Buffer::from(&[])
}

let starting_buffer = starting_buffer.unwrap();

let mut offsets_iter: Box<dyn Iterator<Item=OffsetSize>> = Box::new(starting_buffer.iter().copied());

// 4. Get the last value in the starting buffer as the starting point for the next buffer
// Safety: We already filtered out empty lists
let mut advance_by = *starting_buffer.last().unwrap();

// 5. Iterate over the remaining buffers
for offset_buffer in offset_buffers_iterator {
// 6. Get the last value of the current buffer so we can know how much to advance the next buffer
// Safety: We already filtered out empty lists
let last_value = *offset_buffer.last().unwrap();

// 7. Advance the offset buffer by the last value in the previous buffer
let offset_buffer_iter = offset_buffer
.iter()
// Skip the first value as it is the initial offset of 0
.skip(1)
.map(move |&x| x + advance_by);

// 8. concat the current buffer with the previous buffer
// Chaining keeps the iterator have trusting length
offsets_iter = Box::new(offsets_iter.chain(offset_buffer_iter));
rluvaton marked this conversation as resolved.
Show resolved Hide resolved

// 9. Update the next advance_by
advance_by += last_value;
}

unsafe {
Buffer::from_trusted_len_iter(offsets_iter)
}
}

macro_rules! dict_helper {
($t:ty, $arrays:expr) => {
return Ok(Arc::new(concat_dictionaries::<$t>($arrays)?) as _)
};
}

macro_rules! list_dict_helper {
($t:ty, $o: ty, $arrays:expr) => {
return Ok(Arc::new(concat_list_of_dictionaries::<$o, $t>($arrays)?) as _)
};
}

fn get_capacity(arrays: &[&dyn Array], data_type: &DataType) -> Capacities {
match data_type {
DataType::Utf8 => binary_capacity::<Utf8Type>(arrays),
Expand Down Expand Up @@ -169,6 +320,21 @@ pub fn concat(arrays: &[&dyn Array]) -> Result<ArrayRef, ArrowError> {
_ => unreachable!("illegal dictionary key type {k}")
};
} else {
if let DataType::List(field) = d {
rluvaton marked this conversation as resolved.
Show resolved Hide resolved
if let DataType::Dictionary(k, _) = field.data_type() {
return downcast_integer! {
k.as_ref() => (list_dict_helper, i32, arrays),
_ => unreachable!("illegal dictionary key type {k}")
};
}
} else if let DataType::LargeList(field) = d {
if let DataType::Dictionary(k, _) = field.data_type() {
return downcast_integer! {
k.as_ref() => (list_dict_helper, i64, arrays),
_ => unreachable!("illegal dictionary key type {k}")
};
}
}
let capacity = get_capacity(arrays, d);
concat_fallback(arrays, capacity)
}
Expand Down Expand Up @@ -228,8 +394,9 @@ pub fn concat_batches<'a>(
#[cfg(test)]
mod tests {
use super::*;
use arrow_array::builder::StringDictionaryBuilder;
use arrow_array::builder::{GenericListBuilder, StringDictionaryBuilder};
use arrow_schema::{Field, Schema};
use std::fmt::Debug;

#[test]
fn test_concat_empty_vec() {
Expand Down Expand Up @@ -851,4 +1018,169 @@ mod tests {
assert_eq!(array.null_count(), 10);
assert_eq!(array.logical_null_count(), 10);
}

#[test]
fn concat_dictionary_list_array_simple() {
let scalars = vec![
create_single_row_list_of_dict(vec![Some("a")]),
create_single_row_list_of_dict(vec![Some("a")]),
create_single_row_list_of_dict(vec![Some("b")]),
];

let arrays = scalars.iter().map(|a| a as &(dyn Array)).collect::<Vec<_>>();
let concat_res = concat(arrays.as_slice()).unwrap();

let expected_list = create_list_of_dict(vec![
// Row 1
Some(vec![Some("a")]),
Some(vec![Some("a")]),
Some(vec![Some("b")]),
]);

let list = concat_res.as_list::<i32>();

// Assert that the list is equal to the expected list
list.iter().zip(expected_list.iter()).for_each(|(a, b)| {
assert_eq!(a, b);
});

let dict = list
.values()
.as_dictionary::<Int32Type>()
.downcast_dict::<StringArray>()
.unwrap();
println!("{:?}", dict);

assert_dictionary_has_unique_values::<_, StringArray>(
list.values().as_dictionary::<Int32Type>(),
);
}

#[test]
fn concat_dictionary_list_array_with_multiple_rows() {
rluvaton marked this conversation as resolved.
Show resolved Hide resolved
let scalars = vec![
create_list_of_dict(vec![
// Row 1
Some(vec![Some("a"), Some("c")]),
// Row 2
None,
// Row 3
Some(vec![Some("f"), Some("g"), None]),
// Row 4
Some(vec![Some("c"), Some("f")]),
]),
create_list_of_dict(vec![
// Row 1
Some(vec![Some("a")]),
// Row 2
Some(vec![]),
// Row 3
Some(vec![None, Some("b")]),
// Row 4
Some(vec![Some("d"), Some("e")]),
]),
create_list_of_dict(vec![
// Row 1
Some(vec![Some("g")]),
// Row 2
Some(vec![Some("h"), Some("i")]),
// Row 3
Some(vec![Some("j"), Some("a")]),
// Row 4
Some(vec![Some("d"), Some("e")]),
]),
];
let arrays = scalars
.iter()
.map(|a| a as &(dyn Array))
.collect::<Vec<_>>();
let concat_res = concat(arrays.as_slice()).unwrap();

let expected_list = create_list_of_dict(vec![
// First list:

// Row 1
Some(vec![Some("a"), Some("c")]),
// Row 2
None,
// Row 3
Some(vec![Some("f"), Some("g"), None]),
// Row 4
Some(vec![Some("c"), Some("f")]),
// Second list:
// Row 1
Some(vec![Some("a")]),
// Row 2
Some(vec![]),
// Row 3
Some(vec![None, Some("b")]),
// Row 4
Some(vec![Some("d"), Some("e")]),
// Third list:

// Row 1
Some(vec![Some("g")]),
// Row 2
Some(vec![Some("h"), Some("i")]),
// Row 3
Some(vec![Some("j"), Some("a")]),
// Row 4
Some(vec![Some("d"), Some("e")]),
]);

let list = concat_res.as_list::<i32>();

// Assert that the list is equal to the expected list
list.iter().zip(expected_list.iter()).for_each(|(a, b)| {
assert_eq!(a, b);
});

// Assert that the
assert_dictionary_has_unique_values::<_, StringArray>(
rluvaton marked this conversation as resolved.
Show resolved Hide resolved
list.values().as_dictionary::<Int32Type>(),
);
}

fn create_single_row_list_of_dict(list_items: Vec<Option<&'static str>>) -> GenericListArray<i32> {
let rows = list_items.into_iter().map(|row| Some(row)).collect();

create_list_of_dict(vec![rows])
}

fn create_list_of_dict(rows: Vec<Option<Vec<Option<&'static str>>>>) -> GenericListArray<i32> {
let mut builder =
GenericListBuilder::<i32, _>::new(StringDictionaryBuilder::<Int32Type>::new());

for row in rows {
builder.append_option(row);
}

builder.finish()
}

fn assert_dictionary_has_unique_values<'a, K, V: 'static>(
array: &'a DictionaryArray<K>,
) where
K: ArrowDictionaryKeyType,
V: Sync + Send,
&'a V: ArrayAccessor + IntoIterator,

<&'a V as ArrayAccessor>::Item: Default + Clone + PartialEq + Debug + Ord,
<&'a V as IntoIterator>::Item: Clone + PartialEq + Debug + Ord,
{
let dict = array.downcast_dict::<V>().unwrap();
let mut values = dict.values().clone().into_iter().collect::<Vec<_>>();

// remove duplicates must be sorted first so we can compare
values.sort();

let mut unique_values = values.clone();

unique_values.dedup();

assert_eq!(
values, unique_values,
"There are duplicates in the value list (the value list here is sorted which is only for the assertion)"
);
}
}
Loading