From 0fd97dcf1b8aef962a061f0376f32556f3a3c357 Mon Sep 17 00:00:00 2001 From: Judah Rand <17158624+judahrand@users.noreply.github.com> Date: Mon, 10 Jun 2024 16:56:45 +0100 Subject: [PATCH] Get capacity recursively for `FixedSizeList` --- arrow-select/src/concat.rs | 53 ++++++++++++++++++++++++-------------- 1 file changed, 34 insertions(+), 19 deletions(-) diff --git a/arrow-select/src/concat.rs b/arrow-select/src/concat.rs index 3db3476289b4..75bdf5cb7bed 100644 --- a/arrow-select/src/concat.rs +++ b/arrow-select/src/concat.rs @@ -54,6 +54,22 @@ fn binary_capacity(arrays: &[&dyn Array]) -> Capacities { Capacities::Binary(item_capacity, Some(bytes_capacity)) } +fn fixed_size_list_capacity(arrays: &[&dyn Array], data_type: &DataType) -> Capacities { + if let DataType::FixedSizeList(f, _) = data_type { + let item_capacity = arrays.iter().map(|a| a.len()).sum(); + let values: Vec<&dyn arrow_array::Array> = arrays + .iter() + .map(|a| a.as_fixed_size_list().values().as_ref()) + .collect(); + Capacities::FixedSizeList( + item_capacity, + Some(Box::new(get_capacity(&values, f.data_type()))), + ) + } else { + unreachable!("illegal data type for fixed size list") + } +} + fn concat_dictionaries( arrays: &[&dyn Array], ) -> Result { @@ -107,6 +123,17 @@ macro_rules! dict_helper { }; } +fn get_capacity(arrays: &[&dyn Array], data_type: &DataType) -> Capacities { + match data_type { + DataType::Utf8 => binary_capacity::(arrays), + DataType::LargeUtf8 => binary_capacity::(arrays), + DataType::Binary => binary_capacity::(arrays), + DataType::LargeBinary => binary_capacity::(arrays), + DataType::FixedSizeList(_, _) => fixed_size_list_capacity(arrays, data_type), + _ => Capacities::Array(arrays.iter().map(|a| a.len()).sum()), + } +} + /// Concatenate multiple [Array] of the same type into a single [ArrayRef]. pub fn concat(arrays: &[&dyn Array]) -> Result { if arrays.is_empty() { @@ -124,27 +151,15 @@ pub fn concat(arrays: &[&dyn Array]) -> Result { "It is not possible to concatenate arrays of different data types.".to_string(), )); } - - let capacity = match d { - DataType::Utf8 => binary_capacity::(arrays), - DataType::LargeUtf8 => binary_capacity::(arrays), - DataType::Binary => binary_capacity::(arrays), - DataType::LargeBinary => binary_capacity::(arrays), - DataType::Dictionary(k, _) => downcast_integer! { + if let DataType::Dictionary(k, _) = d { + downcast_integer! { k.as_ref() => (dict_helper, arrays), _ => unreachable!("illegal dictionary key type {k}") - }, - DataType::FixedSizeList(_, size) => { - let item_capacity = arrays.iter().map(|a| a.len()).sum(); - Capacities::FixedSizeList( - item_capacity, - Some(Box::new(Capacities::Array(item_capacity * *size as usize))), - ) - } - _ => Capacities::Array(arrays.iter().map(|a| a.len()).sum()), - }; - - concat_fallback(arrays, capacity) + }; + } else { + let capacity = get_capacity(arrays, d); + concat_fallback(arrays, capacity) + } } /// Concatenates arrays using MutableArrayData