diff --git a/parquet/src/arrow/arrow_reader/selection.rs b/parquet/src/arrow/arrow_reader/selection.rs index 0287e5b42158..ce3fbbf4f091 100644 --- a/parquet/src/arrow/arrow_reader/selection.rs +++ b/parquet/src/arrow/arrow_reader/selection.rs @@ -343,6 +343,16 @@ impl RowSelection { intersect_row_selections(&self.selectors, &other.selectors) } + /// Compute the union of two [`RowSelection`] + /// For example: + /// self: NNYYYYNNYYNYN + /// other: NYNNNNNNN + /// + /// returned: NYYYYYNNYYNYN + pub fn union(&self, other: &Self) -> Self { + union_row_selections(&self.selectors, &other.selectors) + } + /// Returns `true` if this [`RowSelection`] selects any rows pub fn selects_any(&self) -> bool { self.selectors.iter().any(|x| !x.skip) @@ -536,6 +546,92 @@ fn intersect_row_selections(left: &[RowSelector], right: &[RowSelector]) -> RowS iter.collect() } +/// Combine two lists of `RowSelector` return the union of them +/// For example: +/// self: NNYYYYNNYYNYN +/// other: NYNNNNNNY +/// +/// returned: NYYYYYNNYYNYN +/// +/// This can be removed from here once RowSelection::union is in parquet::arrow +fn union_row_selections(left: &[RowSelector], right: &[RowSelector]) -> RowSelection { + let mut l_iter = left.iter().copied().peekable(); + let mut r_iter = right.iter().copied().peekable(); + + let iter = std::iter::from_fn(move || { + loop { + let l = l_iter.peek_mut(); + let r = r_iter.peek_mut(); + + match (l, r) { + (Some(a), _) if a.row_count == 0 => { + l_iter.next().unwrap(); + } + (_, Some(b)) if b.row_count == 0 => { + r_iter.next().unwrap(); + } + (Some(l), Some(r)) => { + return match (l.skip, r.skip) { + // Skip both ranges + (true, true) => { + if l.row_count < r.row_count { + let skip = l.row_count; + r.row_count -= l.row_count; + l_iter.next(); + Some(RowSelector::skip(skip)) + } else { + let skip = r.row_count; + l.row_count -= skip; + r_iter.next(); + Some(RowSelector::skip(skip)) + } + } + // Keep rows from left + (false, true) => { + if l.row_count < r.row_count { + r.row_count -= l.row_count; + l_iter.next() + } else { + let r_row_count = r.row_count; + l.row_count -= r_row_count; + r_iter.next(); + Some(RowSelector::select(r_row_count)) + } + } + // Keep rows from right + (true, false) => { + if l.row_count < r.row_count { + let l_row_count = l.row_count; + r.row_count -= l_row_count; + l_iter.next(); + Some(RowSelector::select(l_row_count)) + } else { + l.row_count -= r.row_count; + r_iter.next() + } + } + // Keep at least one + _ => { + if l.row_count < r.row_count { + r.row_count -= l.row_count; + l_iter.next() + } else { + l.row_count -= r.row_count; + r_iter.next() + } + } + }; + } + (Some(_), None) => return l_iter.next(), + (None, Some(_)) => return r_iter.next(), + (None, None) => return None, + } + } + }); + + iter.collect() +} + #[cfg(test)] mod tests { use super::*; @@ -1213,4 +1309,40 @@ mod tests { ] ); } + + #[test] + fn test_union() { + let selection = RowSelection::from(vec![RowSelector::select(1048576)]); + let result = selection.union(&selection); + assert_eq!(result, selection); + + // NYNYY + let a = RowSelection::from(vec![ + RowSelector::skip(10), + RowSelector::select(10), + RowSelector::skip(10), + RowSelector::select(20), + ]); + + // NNYYNYN + let b = RowSelection::from(vec![ + RowSelector::skip(20), + RowSelector::select(20), + RowSelector::skip(10), + RowSelector::select(10), + RowSelector::skip(10), + ]); + + let result = a.union(&b); + + // NYYYYYN + assert_eq!( + result.iter().collect::>(), + vec![ + &RowSelector::skip(10), + &RowSelector::select(50), + &RowSelector::skip(10), + ] + ); + } }