From d357df7d7ed78e5142ed1ad2ae21fd0fff1aa0f5 Mon Sep 17 00:00:00 2001 From: zhuliquan Date: Sat, 31 Aug 2024 11:14:56 +0800 Subject: [PATCH 01/35] feat: scalar regex match physical expr --- .../physical-expr/src/expressions/mod.rs | 2 + .../src/expressions/scalar_regex_match.rs | 670 ++++++++++++++++++ datafusion/physical-expr/src/planner.rs | 28 +- datafusion/proto/proto/datafusion.proto | 9 + datafusion/proto/src/generated/pbjson.rs | 157 ++++ datafusion/proto/src/generated/prost.rs | 17 +- .../proto/src/physical_plan/from_proto.rs | 22 +- .../proto/src/physical_plan/to_proto.rs | 21 +- 8 files changed, 922 insertions(+), 4 deletions(-) create mode 100644 datafusion/physical-expr/src/expressions/scalar_regex_match.rs diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index f00b49f50314..462236737074 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -30,6 +30,7 @@ mod literal; mod negative; mod no_op; mod not; +mod scalar_regex_match; mod try_cast; mod unknown_column; @@ -50,5 +51,6 @@ pub use literal::{lit, Literal}; pub use negative::{negative, NegativeExpr}; pub use no_op::NoOp; pub use not::{not, NotExpr}; +pub use scalar_regex_match::{scalar_regex_match, ScalarRegexMatchExpr}; pub use try_cast::{try_cast, TryCastExpr}; pub use unknown_column::UnKnownColumn; diff --git a/datafusion/physical-expr/src/expressions/scalar_regex_match.rs b/datafusion/physical-expr/src/expressions/scalar_regex_match.rs new file mode 100644 index 000000000000..badb00659576 --- /dev/null +++ b/datafusion/physical-expr/src/expressions/scalar_regex_match.rs @@ -0,0 +1,670 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::Literal; +use arrow::array::ArrayData; +use arrow_array::{ + Array, ArrayAccessor, BooleanArray, LargeStringArray, StringArray, StringViewArray, +}; +use arrow_buffer::BooleanBufferBuilder; +use arrow_schema::{DataType, Schema}; +use datafusion_common::ScalarValue; +use datafusion_expr::ColumnarValue; +use datafusion_physical_expr_common::physical_expr::{down_cast_any_ref, PhysicalExpr}; +use regex::Regex; +use std::{any::Any, hash::Hash, sync::Arc}; + +/// ScalarRegexMatchExpr +/// Only used when evaluating regexp matching with literal pattern. +/// Example regex expression: c1 ~ '^a' / c1 !~ '^a' / c1 ~* '^a' / c1 !~* '^a'. +/// Literal regexp pattern will be compiled once and cached to be reused in execution. +/// This avoids recompiling the pattern on every execution and speeds up evaluation. 
+#[derive(Clone)] +pub struct ScalarRegexMatchExpr { + negated: bool, + case_insensitive: bool, + expr: Arc, + pattern: Arc, + compiled: Option, +} + +impl ScalarRegexMatchExpr { + pub fn new( + negated: bool, + case_insensitive: bool, + expr: Arc, + pattern: Arc, + ) -> Self { + let mut res = Self { + negated, + case_insensitive, + expr, + pattern, + compiled: None, + }; + res.compile().unwrap(); + res + } + + /// Is negated + pub fn negated(&self) -> bool { + self.negated + } + + /// Is case insensitive + pub fn case_insensitive(&self) -> bool { + self.case_insensitive + } + + /// Input expression + pub fn expr(&self) -> &Arc { + &self.expr + } + + /// Pattern expression + pub fn pattern(&self) -> &Arc { + &self.pattern + } + + /// Compile regex pattern + fn compile(&mut self) -> datafusion_common::Result<()> { + let scalar_pattern = + self.pattern + .as_any() + .downcast_ref::() + .and_then(|pattern| match pattern.value() { + ScalarValue::Null + | ScalarValue::Utf8(None) + | ScalarValue::Utf8View(None) + | ScalarValue::LargeUtf8(None) => Some(None), + ScalarValue::Utf8(Some(pattern)) + | ScalarValue::Utf8View(Some(pattern)) + | ScalarValue::LargeUtf8(Some(pattern)) => { + let mut pattern = pattern.to_string(); + if self.case_insensitive { + pattern = format!("(?i){}", pattern); + } + Some(Some(pattern)) + } + _ => None, + }); + match scalar_pattern { + Some(Some(scalar_pattern)) => Regex::new(scalar_pattern.as_str()) + .map(|compiled| { + self.compiled = Some(compiled); + }) + .map_err(|err| { + datafusion_common::DataFusionError::Internal(format!( + "Failed to compile regex: {}", + err + )) + }), + Some(None) => { + self.compiled = None; + Ok(()) + } + None => Err(datafusion_common::DataFusionError::Internal(format!( + "Regex pattern({}) isn't literal string", + self.pattern + ))), + } + } + + /// Operator name + fn op_name(&self) -> &str { + match (self.negated, self.case_insensitive) { + (false, false) => "MATCH", + (true, false) => "NOT MATCH", + (false, 
true) => "IMATCH", + (true, true) => "NOT IMATCH", + } + } +} + +impl ScalarRegexMatchExpr { + /// Evaluate the scalar regex match expression match array value + fn evaluate_array( + &self, + array: &Arc, + ) -> datafusion_common::Result { + macro_rules! downcast_string_array { + ($ARRAY:expr, $ARRAY_TYPE:ident, $ERR_MSG:expr) => { + &($ARRAY + .as_any() + .downcast_ref::<$ARRAY_TYPE>() + .expect($ERR_MSG)) + }; + } + match array.data_type() { + DataType::Null => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))) + }, + DataType::Utf8 => array_regexp_match( + downcast_string_array!(array, StringArray, "Failed to downcast StringArray"), + self.compiled.as_ref().unwrap(), + self.negated, + ), + DataType::Utf8View => array_regexp_match( + downcast_string_array!(array, StringViewArray, "Failed to downcast StringViewArray"), + self.compiled.as_ref().unwrap(), + self.negated, + ), + DataType::LargeUtf8 => array_regexp_match( + downcast_string_array!(array, LargeStringArray, "Failed to downcast LargeStringArray"), + self.compiled.as_ref().unwrap(), + self.negated, + ), + other=> datafusion_common::internal_err!( + "Data type {:?} not supported for ScalarRegexMatchExpr, expect Utf8|Utf8View|LargeUtf8", other + ), + } + } + + /// Evaluate the scalar regex match expression match scalar value + fn evaluate_scalar( + &self, + scalar: &ScalarValue, + ) -> datafusion_common::Result { + match scalar { + ScalarValue::Null + | ScalarValue::Utf8(None) + | ScalarValue::Utf8View(None) + | ScalarValue::LargeUtf8(None) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))), + ScalarValue::Utf8(Some(scalar)) + | ScalarValue::Utf8View(Some(scalar)) + | ScalarValue::LargeUtf8(Some(scalar)) => { + let mut result = self.compiled.as_ref().unwrap().is_match(scalar); + if self.negated { + result = !result; + } + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(result)))) + }, + other=> datafusion_common::internal_err!( + "Data type {:?} not supported for ScalarRegexMatchExpr, expect 
Utf8|Utf8View|LargeUtf8", other + ), + } + } +} + +impl std::hash::Hash for ScalarRegexMatchExpr { + fn hash(&self, state: &mut H) { + self.negated.hash(state); + self.case_insensitive.hash(state); + self.expr.hash(state); + self.pattern.hash(state); + } +} + +impl std::cmp::PartialEq for ScalarRegexMatchExpr { + fn eq(&self, other: &Self) -> bool { + self.negated.eq(&other.negated) + && self.case_insensitive.eq(&other.case_insensitive) + && self.expr.eq(&other.expr) + && self.pattern.eq(&other.pattern) + } +} + +impl std::fmt::Debug for ScalarRegexMatchExpr { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ScalarRegexMatchExpr") + .field("negated", &self.negated) + .field("case_insensitive", &self.case_insensitive) + .field("expr", &self.expr) + .field("pattern", &self.pattern) + .finish() + } +} + +impl std::fmt::Display for ScalarRegexMatchExpr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{} {} {}", self.expr, self.op_name(), self.pattern) + } +} + +impl PhysicalExpr for ScalarRegexMatchExpr { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn data_type( + &self, + _: &arrow_schema::Schema, + ) -> datafusion_common::Result { + Ok(DataType::Boolean) + } + + fn nullable( + &self, + input_schema: &arrow_schema::Schema, + ) -> datafusion_common::Result { + Ok(self.expr.nullable(input_schema)? || self.pattern.nullable(input_schema)?) 
+ } + + fn evaluate( + &self, + batch: &arrow_array::RecordBatch, + ) -> datafusion_common::Result { + self.expr + .evaluate(batch) + .and_then(|lhs| { + if self.compiled.is_some() { + match &lhs { + ColumnarValue::Array(array) => self.evaluate_array(array), + ColumnarValue::Scalar(scalar) => self.evaluate_scalar(scalar), + } + } else { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))) + } + }) + .and_then(|result| result.into_array(batch.num_rows())) + .map(ColumnarValue::Array) + } + + fn children(&self) -> Vec<&std::sync::Arc> { + vec![&self.expr, &self.pattern] + } + + fn with_new_children( + self: std::sync::Arc, + children: Vec>, + ) -> datafusion_common::Result> { + Ok(Arc::new(ScalarRegexMatchExpr::new( + self.negated, + self.case_insensitive, + Arc::clone(&children[0]), + Arc::clone(&children[1]), + ))) + } + + fn dyn_hash(&self, state: &mut dyn std::hash::Hasher) { + let mut s = state; + self.hash(&mut s); + } +} + +impl PartialEq for ScalarRegexMatchExpr { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| self == x) + .unwrap_or(false) + } +} + +/// It is used for scalar regexp matching and copy from arrow-rs +fn array_regexp_match( + array: &dyn ArrayAccessor, + regex: &Regex, + negated: bool, +) -> datafusion_common::Result { + let null_bit_buffer = array.nulls().map(|x| x.inner().sliced()); + let mut buffer_builder = BooleanBufferBuilder::new(array.len()); + + if regex.as_str().is_empty() { + buffer_builder.append_n(array.len(), true); + } else { + for i in 0..array.len() { + let value = array.value(i); + buffer_builder.append(regex.is_match(value)); + } + } + + let buffer = buffer_builder.into(); + let bool_array = BooleanArray::from(unsafe { + ArrayData::new_unchecked( + DataType::Boolean, + array.len(), + None, + null_bit_buffer, + 0, + vec![buffer], + vec![], + ) + }); + + let bool_array = if negated { + arrow::compute::kernels::boolean::not(&bool_array) + } else { + Ok(bool_array) + }; + + 
bool_array + .map_err(|err| { + datafusion_common::DataFusionError::Execution(format!( + "Failed to evaluate regex: {}", + err + )) + }) + .map(|bool_array| ColumnarValue::Array(Arc::new(bool_array))) +} + +/// Create a scalar regex match expression, erroring if the argument types are not compatible. +pub fn scalar_regex_match( + negated: bool, + case_insensitive: bool, + expr: Arc, + pattern: Arc, + input_schema: &Schema, +) -> datafusion_common::Result> { + let valid_data_type = |data_type: &DataType| { + if !matches!( + data_type, + DataType::Null | DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8 + ) { + return datafusion_common::internal_err!( + "The type {data_type} not supported for scalar_regex_match, expect Null|Utf8|Utf8View|LargeUtf8" + ); + } + Ok(()) + }; + + for arg_expr in [&expr, &pattern] { + arg_expr + .data_type(input_schema) + .and_then(|data_type| valid_data_type(&data_type))?; + } + + Ok(Arc::new(ScalarRegexMatchExpr::new( + negated, + case_insensitive, + expr, + pattern, + ))) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::expressions::{col, lit}; + use arrow::record_batch::RecordBatch; + use arrow_array::ArrayRef; + use arrow_array::NullArray; + use arrow_schema::Field; + use arrow_schema::Schema; + use rstest::rstest; + use std::sync::Arc; + + fn test_schema(typ: DataType) -> Schema { + Schema::new(vec![Field::new("c1", typ, false)]) + } + + #[rstest( + negated, case_insensitive, typ, a_vec, b_lit, c_vec, + case( + false, false, DataType::Utf8, + Arc::new(StringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + "^a", + Arc::new(BooleanArray::from(vec![true, false, false, false, false])), + ), + case( + false, true, DataType::Utf8, + Arc::new(StringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + "^a", + Arc::new(BooleanArray::from(vec![true, false, true, false, false])), + ), + case( + true, false, DataType::Utf8, + Arc::new(StringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + "^a", + 
Arc::new(BooleanArray::from(vec![false, true, true, true, true])), + ), + case( + true, true, DataType::Utf8, + Arc::new(StringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + "^a", + Arc::new(BooleanArray::from(vec![false, true, false, true, true])), + ), + case( + true, true, DataType::Utf8, + Arc::new(StringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + ScalarValue::Utf8(None), + Arc::new(BooleanArray::from(vec![None, None, None, None, None])), + ), + case( + false, false, DataType::LargeUtf8, + Arc::new(LargeStringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + ScalarValue::LargeUtf8(Some("^a".to_string())), + Arc::new(BooleanArray::from(vec![true, false, false, false, false])), + ), + case( + false, true, DataType::LargeUtf8, + Arc::new(LargeStringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + ScalarValue::LargeUtf8(Some("^a".to_string())), + Arc::new(BooleanArray::from(vec![true, false, true, false, false])), + ), + case( + true, false, DataType::LargeUtf8, + Arc::new(LargeStringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + ScalarValue::LargeUtf8(Some("^a".to_string())), + Arc::new(BooleanArray::from(vec![false, true, true, true, true])), + ), + case( + true, true, DataType::LargeUtf8, + Arc::new(LargeStringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + ScalarValue::LargeUtf8(Some("^a".to_string())), + Arc::new(BooleanArray::from(vec![false, true, false, true, true])), + ), + case( + true, true, DataType::LargeUtf8, + Arc::new(LargeStringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + ScalarValue::LargeUtf8(None), + Arc::new(BooleanArray::from(vec![None, None, None, None, None])), + ), + case( + false, false, DataType::Utf8View, + Arc::new(StringViewArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + ScalarValue::Utf8View(Some("^a".to_string())), + Arc::new(BooleanArray::from(vec![true, false, false, false, false])), + ), + case( + false, true, DataType::Utf8View, + 
Arc::new(StringViewArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + ScalarValue::Utf8View(Some("^a".to_string())), + Arc::new(BooleanArray::from(vec![true, false, true, false, false])), + ), + case( + true, false, DataType::Utf8View, + Arc::new(StringViewArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + ScalarValue::Utf8View(Some("^a".to_string())), + Arc::new(BooleanArray::from(vec![false, true, true, true, true])), + ), + case( + true, true, DataType::Utf8View, + Arc::new(StringViewArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + ScalarValue::Utf8View(Some("^a".to_string())), + Arc::new(BooleanArray::from(vec![false, true, false, true, true])), + ), + case( + true, true, DataType::Utf8View, + Arc::new(StringViewArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + ScalarValue::Utf8View(None), + Arc::new(BooleanArray::from(vec![None, None, None, None, None])), + ), + case( + true, true, DataType::Null, + Arc::new(NullArray::new(5)), + ScalarValue::Utf8View(Some("^a".to_string())), + Arc::new(BooleanArray::from(vec![None, None, None, None, None])), + ), + )] + fn test_scalar_regex_match_array( + negated: bool, + case_insensitive: bool, + typ: DataType, + a_vec: ArrayRef, + b_lit: impl datafusion_expr::Literal, + c_vec: ArrayRef, + ) { + let schema = test_schema(typ); + let left = col("c1", &schema).unwrap(); + let right = lit(b_lit); + + // verify that we can construct the expression + let expression = + scalar_regex_match(negated, case_insensitive, left, right, &schema).unwrap(); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a_vec]).unwrap(); + + // verify that the expression's type is correct + assert_eq!(expression.data_type(&schema).unwrap(), DataType::Boolean); + + // compute + let result = expression + .evaluate(&batch) + .expect("Error evaluating expression"); + + if let ColumnarValue::Array(array) = result { + let array = array + .as_any() + .downcast_ref::() + .expect("failed to downcast to BooleanArray"); + + let 
c_vec = c_vec + .as_any() + .downcast_ref::() + .expect("failed to downcast to BooleanArray"); + // verify that the result is correct + assert_eq!(array, c_vec); + } else { + panic!("result was not an array"); + } + } + + #[rstest( + negated, case_insensitive, typ, a_lit, b_lit, flag, + case( + false, false, DataType::Utf8, "abc", "^a", Some(true), + ), + case( + false, true, DataType::Utf8, "Abc", "^a", Some(true), + ), + case( + true, false, DataType::Utf8, "abc", "^a", Some(false), + ), + case( + true, true, DataType::Utf8, "Abc", "^a", Some(false), + ), + case( + true, true, DataType::Utf8, + ScalarValue::Utf8(Some("Abc".to_string())), + ScalarValue::Utf8(None), + None, + ), + case( + false, false, DataType::Utf8, + ScalarValue::Utf8(Some("abc".to_string())), + ScalarValue::LargeUtf8(Some("^a".to_string())), + Some(true), + ), + case( + false, true, DataType::Utf8, + ScalarValue::Utf8(Some("Abc".to_string())), + ScalarValue::LargeUtf8(Some("^a".to_string())), + Some(true), + ), + case( + true, false, DataType::Utf8, + ScalarValue::Utf8(Some("abc".to_string())), + ScalarValue::LargeUtf8(Some("^a".to_string())), + Some(false), + ), + case( + true, true, DataType::Utf8, + ScalarValue::Utf8(Some("Abc".to_string())), + ScalarValue::LargeUtf8(Some("^a".to_string())), + Some(false), + ), + case( + true, true, DataType::Utf8, + ScalarValue::Utf8(Some("Abc".to_string())), + ScalarValue::LargeUtf8(None), + None, + ), + )] + fn test_scalar_regex_match_scalar( + negated: bool, + case_insensitive: bool, + typ: DataType, + a_lit: impl datafusion_expr::Literal, + b_lit: impl datafusion_expr::Literal, + flag: Option, + ) { + let left = lit(a_lit); + let right = lit(b_lit); + let schema = test_schema(typ); + let expression = + scalar_regex_match(negated, case_insensitive, left, right, &schema).unwrap(); + let num_rows: usize = 3; + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(StringArray::from([""].repeat(num_rows)))], + ) + .unwrap(); + + // 
verify that the expression's type is correct + assert_eq!(expression.data_type(&schema).unwrap(), DataType::Boolean); + + // compute + let result = expression + .evaluate(&batch) + .expect("Error evaluating expression"); + + if let ColumnarValue::Array(array) = result { + let array = array + .as_any() + .downcast_ref::() + .expect("failed to downcast to BooleanArray"); + + // verify that the result is correct + let c_vec = [flag].repeat(batch.num_rows()); + assert_eq!(array, &BooleanArray::from(c_vec)); + } else { + panic!("result was not an array"); + } + } + + #[rstest( + expr, pattern, + case( + col("c1", &test_schema(DataType::Utf8)).unwrap(), + lit(1), + ), + case( + lit(1), + col("c1", &test_schema(DataType::Utf8)).unwrap(), + ), + )] + #[should_panic] + fn test_scalar_regex_match_panic( + expr: Arc, + pattern: Arc, + ) { + let _ = + scalar_regex_match(false, false, expr, pattern, &test_schema(DataType::Utf8)) + .unwrap(); + } + + #[rstest( + pattern, + case(col("c1", &test_schema(DataType::Utf8)).unwrap()), // not literal + case(lit(1)), // not literal string + case(lit("\\x{202e")), // wrong regex pattern + )] + #[should_panic] + fn test_scalar_regex_match_compile_error(pattern: Arc) { + let _ = ScalarRegexMatchExpr::new(false, false, lit("a"), pattern); + } +} diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 906ca9fd1093..d4d33ba32faf 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -17,6 +17,7 @@ use std::sync::Arc; +use crate::expressions::scalar_regex_match; use crate::scalar_function; use crate::{ expressions::{self, binary, like, similar_to, Column, Literal}, @@ -191,7 +192,32 @@ pub fn create_physical_expr( // // There should be no coercion during physical // planning. 
- binary(lhs, *op, rhs, input_schema) + if let Expr::Literal( + ScalarValue::Null + | ScalarValue::Utf8(_) + | ScalarValue::Utf8View(_) + | ScalarValue::LargeUtf8(_), + ) = right.as_ref() + { + // handle literal regexp pattern case to `ScalarRegexMatchExpr` + match *op { + Operator::RegexMatch => { + scalar_regex_match(false, false, lhs, rhs, input_schema) + } + Operator::RegexNotMatch => { + scalar_regex_match(true, false, lhs, rhs, input_schema) + } + Operator::RegexIMatch => { + scalar_regex_match(false, true, lhs, rhs, input_schema) + } + Operator::RegexNotIMatch => { + scalar_regex_match(true, true, lhs, rhs, input_schema) + } + _ => binary(lhs, *op, rhs, input_schema), + } + } else { + binary(lhs, *op, rhs, input_schema) + } } Expr::Like(Like { negated, diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index cb0235f1f20a..d1f81a115d60 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -830,6 +830,8 @@ message PhysicalExprNode { PhysicalExtensionExprNode extension = 19; UnknownColumn unknown_column = 20; + + PhysicalScalarRegexMatchExprNode scalar_regex_match_expr = 21; } } @@ -943,6 +945,13 @@ message PhysicalExtensionExprNode { repeated PhysicalExprNode inputs = 2; } +message PhysicalScalarRegexMatchExprNode { + bool negated = 1; + bool case_insensitive = 2; + PhysicalExprNode expr = 3; + PhysicalExprNode pattern = 4; +} + message FilterExecNode { PhysicalPlanNode input = 1; PhysicalExprNode expr = 2; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index f920e16d0a71..894d8329015e 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -13994,6 +13994,9 @@ impl serde::Serialize for PhysicalExprNode { physical_expr_node::ExprType::UnknownColumn(v) => { struct_ser.serialize_field("unknownColumn", v)?; } + physical_expr_node::ExprType::ScalarRegexMatchExpr(v) => { + 
struct_ser.serialize_field("scalarRegexMatchExpr", v)?; + } } } struct_ser.end() @@ -14036,6 +14039,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { "extension", "unknown_column", "unknownColumn", + "scalar_regex_match_expr", + "scalarRegexMatchExpr", ]; #[allow(clippy::enum_variant_names)] @@ -14058,6 +14063,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { LikeExpr, Extension, UnknownColumn, + ScalarRegexMatchExpr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -14097,6 +14103,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { "likeExpr" | "like_expr" => Ok(GeneratedField::LikeExpr), "extension" => Ok(GeneratedField::Extension), "unknownColumn" | "unknown_column" => Ok(GeneratedField::UnknownColumn), + "scalarRegexMatchExpr" | "scalar_regex_match_expr" => Ok(GeneratedField::ScalarRegexMatchExpr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -14243,6 +14250,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { return Err(serde::de::Error::duplicate_field("unknownColumn")); } expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::UnknownColumn) +; + } + GeneratedField::ScalarRegexMatchExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("scalarRegexMatchExpr")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::ScalarRegexMatchExpr) ; } } @@ -15700,6 +15714,149 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { deserializer.deserialize_struct("datafusion.PhysicalPlanNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for PhysicalScalarRegexMatchExprNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.negated { + len += 1; + } + if self.case_insensitive { + len 
+= 1; + } + if self.expr.is_some() { + len += 1; + } + if self.pattern.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalScalarRegexMatchExprNode", len)?; + if self.negated { + struct_ser.serialize_field("negated", &self.negated)?; + } + if self.case_insensitive { + struct_ser.serialize_field("caseInsensitive", &self.case_insensitive)?; + } + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; + } + if let Some(v) = self.pattern.as_ref() { + struct_ser.serialize_field("pattern", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for PhysicalScalarRegexMatchExprNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "negated", + "case_insensitive", + "caseInsensitive", + "expr", + "pattern", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Negated, + CaseInsensitive, + Expr, + Pattern, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "negated" => Ok(GeneratedField::Negated), + "caseInsensitive" | "case_insensitive" => Ok(GeneratedField::CaseInsensitive), + "expr" => Ok(GeneratedField::Expr), + "pattern" => Ok(GeneratedField::Pattern), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for 
GeneratedVisitor { + type Value = PhysicalScalarRegexMatchExprNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.PhysicalScalarRegexMatchExprNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut negated__ = None; + let mut case_insensitive__ = None; + let mut expr__ = None; + let mut pattern__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Negated => { + if negated__.is_some() { + return Err(serde::de::Error::duplicate_field("negated")); + } + negated__ = Some(map_.next_value()?); + } + GeneratedField::CaseInsensitive => { + if case_insensitive__.is_some() { + return Err(serde::de::Error::duplicate_field("caseInsensitive")); + } + case_insensitive__ = Some(map_.next_value()?); + } + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); + } + expr__ = map_.next_value()?; + } + GeneratedField::Pattern => { + if pattern__.is_some() { + return Err(serde::de::Error::duplicate_field("pattern")); + } + pattern__ = map_.next_value()?; + } + } + } + Ok(PhysicalScalarRegexMatchExprNode { + negated: negated__.unwrap_or_default(), + case_insensitive: case_insensitive__.unwrap_or_default(), + expr: expr__, + pattern: pattern__, + }) + } + } + deserializer.deserialize_struct("datafusion.PhysicalScalarRegexMatchExprNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for PhysicalScalarUdfNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index a2be3207acab..3d4963b35cf7 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1165,7 +1165,7 @@ pub struct PhysicalExtensionNode { pub struct PhysicalExprNode { #[prost( oneof = "physical_expr_node::ExprType", - tags = 
"1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 18, 19, 20" + tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 18, 19, 20, 21" )] pub expr_type: ::core::option::Option, } @@ -1216,6 +1216,10 @@ pub mod physical_expr_node { Extension(super::PhysicalExtensionExprNode), #[prost(message, tag = "20")] UnknownColumn(super::UnknownColumn), + #[prost(message, tag = "21")] + ScalarRegexMatchExpr( + ::prost::alloc::boxed::Box, + ), } } #[derive(Clone, PartialEq, ::prost::Message)] @@ -1396,6 +1400,17 @@ pub struct PhysicalExtensionExprNode { pub inputs: ::prost::alloc::vec::Vec, } #[derive(Clone, PartialEq, ::prost::Message)] +pub struct PhysicalScalarRegexMatchExprNode { + #[prost(bool, tag = "1")] + pub negated: bool, + #[prost(bool, tag = "2")] + pub case_insensitive: bool, + #[prost(message, optional, boxed, tag = "3")] + pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, boxed, tag = "4")] + pub pattern: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[derive(Clone, PartialEq, ::prost::Message)] pub struct FilterExecNode { #[prost(message, optional, boxed, tag = "1")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 5bf6218cb90e..c9d6aa7ace28 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -38,7 +38,7 @@ use datafusion::logical_expr::WindowFunctionDefinition; use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ in_list, BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr, LikeExpr, - Literal, NegativeExpr, NotExpr, TryCastExpr, UnKnownColumn, + Literal, NegativeExpr, NotExpr, ScalarRegexMatchExpr, TryCastExpr, UnKnownColumn, }; use datafusion::physical_plan::windows::{create_window_expr, schema_add_window_field}; use 
datafusion::physical_plan::{Partitioning, PhysicalExpr, WindowExpr}; @@ -391,6 +391,26 @@ pub fn parse_physical_expr( .collect::>()?; (codec.try_decode_expr(extension.expr.as_slice(), &inputs)?) as _ } + ExprType::ScalarRegexMatchExpr(scalar_match_expr) => { + Arc::new(ScalarRegexMatchExpr::new( + scalar_match_expr.negated, + scalar_match_expr.case_insensitive, + parse_required_physical_expr( + scalar_match_expr.expr.as_deref(), + registry, + "expr", + input_schema, + codec, + )?, + parse_required_physical_expr( + scalar_match_expr.pattern.as_deref(), + registry, + "pattern", + input_schema, + codec, + )?, + )) + } }; Ok(pexpr) diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 3805b970591d..3f91ef93d492 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -23,7 +23,7 @@ use datafusion::physical_expr::window::{SlidingAggregateWindowExpr, StandardWind use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ BinaryExpr, CaseExpr, CastExpr, Column, InListExpr, IsNotNullExpr, IsNullExpr, - Literal, NegativeExpr, NotExpr, TryCastExpr, UnKnownColumn, + Literal, NegativeExpr, NotExpr, NthValue, ScalarRegexMatchExpr, TryCastExpr, UnKnownColumn, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{PlainAggregateWindowExpr, WindowUDFExpr}; @@ -364,6 +364,25 @@ pub fn serialize_physical_expr( }, ))), }) + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some( + protobuf::physical_expr_node::ExprType::ScalarRegexMatchExpr(Box::new( + protobuf::PhysicalScalarRegexMatchExprNode { + negated: expr.negated(), + case_insensitive: expr.case_insensitive(), + expr: Some(Box::new(serialize_physical_expr( + expr.expr(), + codec, + )?)), + pattern: Some(Box::new(serialize_physical_expr( + 
expr.pattern(), + codec, + )?)), + }, + )), + ), + }) } else { let mut buf: Vec = vec![]; match codec.try_encode_expr(value, &mut buf) { From 02f58b23df0df0f72b7d3355ed101a663072d00d Mon Sep 17 00:00:00 2001 From: zhuliquan Date: Wed, 18 Sep 2024 00:36:33 +0800 Subject: [PATCH 02/35] bench: add scalar regex match benchmarks --- datafusion/physical-expr/Cargo.toml | 4 + .../benches/scalar_regex_match.rs | 121 ++++++++++++++++++ 2 files changed, 125 insertions(+) create mode 100644 datafusion/physical-expr/benches/scalar_regex_match.rs diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index db3e0e10d816..f821134bf9f2 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -71,3 +71,7 @@ name = "case_when" [[bench]] harness = false name = "is_null" + +[[bench]] +harness = false +name = "scalar_regex_match" diff --git a/datafusion/physical-expr/benches/scalar_regex_match.rs b/datafusion/physical-expr/benches/scalar_regex_match.rs new file mode 100644 index 000000000000..680843c0cb56 --- /dev/null +++ b/datafusion/physical-expr/benches/scalar_regex_match.rs @@ -0,0 +1,121 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::sync::Arc; + +use arrow_array::{RecordBatch, StringArray}; +use arrow_schema::{DataType, Field, Schema}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr_common::operator::Operator; +use datafusion_physical_expr::expressions::{binary, col, lit, scalar_regex_match}; +use hashbrown::HashMap; +use rand::distributions::{Alphanumeric, DistString}; + +/// make a record batch with one column and n rows +/// this record batch is single string column is used for +/// scalar regex match benchmarks +fn make_record_batch(rows: usize, string_length: usize, schema: Schema) -> RecordBatch { + let mut rng = rand::thread_rng(); + let mut array = Vec::with_capacity(rows); + for _ in 0..rows { + let data_line = Alphanumeric.sample_string(&mut rng, string_length); + array.push(Some(data_line)); + } + let array = StringArray::from(array); + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).unwrap() +} + +fn scalar_regex_match_benchmark(c: &mut Criterion) { + // make common schema + let column = "string"; + let schema = Schema::new(vec![Field::new(column, DataType::Utf8, true)]); + + // meke test record batch + let test_batch = [ + (10, make_record_batch(10, 100, schema.clone())), + (100, make_record_batch(100, 100, schema.clone())), + (1000, make_record_batch(1000, 100, schema.clone())), + (2000, make_record_batch(2000, 100, schema.clone())), + ] + .iter() + .map(|(k, v)| (*k, v.clone())) + .collect::>(); + + // string column + let string_col = col(column, &schema).unwrap(); + + // some pattern literal + let pattern_lit = [ + ("email".to_string(), lit(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$")), + ("url".to_string(), lit(r"^(https?|ftp)://[-a-zA-Z0-9+&@#/%?=~_|!:,.;]*[-a-zA-Z0-9+&@#/%=~_|]$")), + ("ip".to_string(), lit(r"^((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$")), + ("phone".to_string(), lit(r"^(\+\d{1,2}\s?)?\(?\d{3}\)?[\s.-]?\d{3}[\s.-]?\d{4}$")), + 
("zip_code".to_string(), lit(r"^\d{5}(?:[-\s]\d{4})?$")), + ].iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect::>(); + + for (name, regexp_lit) in pattern_lit.iter() { + for (rows, batch) in test_batch.iter() { + for iter in [10, 20, 50, 100] { + // scalar regex match benchmarks + let bench_name = format!( + "scalar_regex_match_pattern_{}_rows_{}_iter_{}", + name, rows, iter + ); + c.bench_function(bench_name.as_str(), |b| { + let expr = scalar_regex_match( + false, + false, + string_col.clone(), + regexp_lit.clone(), + &schema, + ) + .unwrap(); + b.iter(|| { + for _ in 0..iter { + expr.evaluate(black_box(batch)).unwrap(); + } + }); + }); + + // binary regex match benchmarks + let bench_name = format!( + "binary_regex_match_pattern_{}_rows_{}_iter_{}", + name, rows, iter + ); + c.bench_function(bench_name.as_str(), |b| { + let expr = binary( + string_col.clone(), + Operator::RegexMatch, + regexp_lit.clone(), + &schema, + ) + .unwrap(); + b.iter(|| { + for _ in 0..iter { + expr.evaluate(black_box(batch)).unwrap(); + } + }); + }); + } + } + } +} + +criterion_group!(benches, scalar_regex_match_benchmark); +criterion_main!(benches); From 2dcb31770fbba5f9ce3ffd9f287162e843bb194e Mon Sep 17 00:00:00 2001 From: zhuliquan Date: Wed, 18 Sep 2024 20:52:15 +0800 Subject: [PATCH 03/35] feat: apply scalar_regex_match optimize to similar_to case --- datafusion/physical-expr/src/planner.rs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index d4d33ba32faf..9afb3cc36bb1 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -258,6 +258,23 @@ pub fn create_physical_expr( create_physical_expr(expr, input_dfschema, execution_props)?; let physical_pattern = create_physical_expr(pattern, input_dfschema, execution_props)?; + + if let Expr::Literal( + ScalarValue::Null + | ScalarValue::Utf8(_) + | ScalarValue::Utf8View(_) + | 
ScalarValue::LargeUtf8(_), + ) = pattern.as_ref() + { + // handle literal regexp pattern case to `ScalarRegexMatchExpr` + return scalar_regex_match( + *negated, + *case_insensitive, + physical_expr, + physical_pattern, + input_schema, + ); + } similar_to(*negated, *case_insensitive, physical_expr, physical_pattern) } Expr::Case(case) => { From 3a33a71b39394f41550eeeb38a19f6aa6fa1b460 Mon Sep 17 00:00:00 2001 From: zhuliquan Date: Wed, 18 Sep 2024 22:17:21 +0800 Subject: [PATCH 04/35] minor: regen datafusion protobuf --- datafusion/physical-expr/Cargo.toml | 1 + .../src/expressions/scalar_regex_match.rs | 100 +++++++++--------- 2 files changed, 52 insertions(+), 49 deletions(-) diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index f821134bf9f2..522eaea8f67c 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -53,6 +53,7 @@ itertools = { workspace = true, features = ["use_std"] } log = { workspace = true } paste = "^1.0" petgraph = "0.6.2" +regex = { workspace = true } [dev-dependencies] arrow = { workspace = true, features = ["test_utils"] } diff --git a/datafusion/physical-expr/src/expressions/scalar_regex_match.rs b/datafusion/physical-expr/src/expressions/scalar_regex_match.rs index badb00659576..cc446f3328d5 100644 --- a/datafusion/physical-expr/src/expressions/scalar_regex_match.rs +++ b/datafusion/physical-expr/src/expressions/scalar_regex_match.rs @@ -22,11 +22,16 @@ use arrow_array::{ }; use arrow_buffer::BooleanBufferBuilder; use arrow_schema::{DataType, Schema}; -use datafusion_common::ScalarValue; +use datafusion_common::{Result as DFResult, ScalarValue}; use datafusion_expr::ColumnarValue; -use datafusion_physical_expr_common::physical_expr::{down_cast_any_ref, PhysicalExpr}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use regex::Regex; -use std::{any::Any, hash::Hash, sync::Arc}; +use std::{ + any::Any, + fmt::{Debug, Display, Formatter, Result as 
FmtResult}, + hash::Hash, + sync::Arc, +}; /// ScalarRegexMatchExpr /// Only used when evaluating regexp matching with literal pattern. @@ -133,9 +138,7 @@ impl ScalarRegexMatchExpr { (true, true) => "NOT IMATCH", } } -} -impl ScalarRegexMatchExpr { /// Evaluate the scalar regex match expression match array value fn evaluate_array( &self, @@ -200,16 +203,9 @@ impl ScalarRegexMatchExpr { } } -impl std::hash::Hash for ScalarRegexMatchExpr { - fn hash(&self, state: &mut H) { - self.negated.hash(state); - self.case_insensitive.hash(state); - self.expr.hash(state); - self.pattern.hash(state); - } -} +impl Eq for ScalarRegexMatchExpr {} -impl std::cmp::PartialEq for ScalarRegexMatchExpr { +impl PartialEq for ScalarRegexMatchExpr { fn eq(&self, other: &Self) -> bool { self.negated.eq(&other.negated) && self.case_insensitive.eq(&self.case_insensitive) @@ -218,8 +214,17 @@ impl std::cmp::PartialEq for ScalarRegexMatchExpr { } } -impl std::fmt::Debug for ScalarRegexMatchExpr { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl Hash for ScalarRegexMatchExpr { + fn hash(&self, state: &mut H) { + self.negated.hash(state); + self.case_insensitive.hash(state); + self.expr.hash(state); + self.pattern.hash(state); + } +} + +impl Debug for ScalarRegexMatchExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { f.debug_struct("ScalarRegexMatchExpr") .field("negated", &self.negated) .field("case_insensitive", &self.case_insensitive) @@ -229,35 +234,26 @@ impl std::fmt::Debug for ScalarRegexMatchExpr { } } -impl std::fmt::Display for ScalarRegexMatchExpr { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { +impl Display for ScalarRegexMatchExpr { + fn fmt(&self, f: &mut Formatter) -> FmtResult { write!(f, "{} {} {}", self.expr, self.op_name(), self.pattern) } } impl PhysicalExpr for ScalarRegexMatchExpr { - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } - fn data_type( - &self, - _: &arrow_schema::Schema, 
- ) -> datafusion_common::Result { + fn data_type(&self, _: &Schema) -> DFResult { Ok(DataType::Boolean) } - fn nullable( - &self, - input_schema: &arrow_schema::Schema, - ) -> datafusion_common::Result { + fn nullable(&self, input_schema: &Schema) -> DFResult { Ok(self.expr.nullable(input_schema)? || self.pattern.nullable(input_schema)?) } - fn evaluate( - &self, - batch: &arrow_array::RecordBatch, - ) -> datafusion_common::Result { + fn evaluate(&self, batch: &arrow_array::RecordBatch) -> DFResult { self.expr .evaluate(batch) .and_then(|lhs| { @@ -274,14 +270,14 @@ impl PhysicalExpr for ScalarRegexMatchExpr { .map(ColumnarValue::Array) } - fn children(&self) -> Vec<&std::sync::Arc> { + fn children(&self) -> Vec<&Arc> { vec![&self.expr, &self.pattern] } fn with_new_children( - self: std::sync::Arc, - children: Vec>, - ) -> datafusion_common::Result> { + self: Arc, + children: Vec>, + ) -> DFResult> { Ok(Arc::new(ScalarRegexMatchExpr::new( self.negated, self.case_insensitive, @@ -290,18 +286,24 @@ impl PhysicalExpr for ScalarRegexMatchExpr { ))) } - fn dyn_hash(&self, state: &mut dyn std::hash::Hasher) { - let mut s = state; - self.hash(&mut s); - } -} - -impl PartialEq for ScalarRegexMatchExpr { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| self == x) - .unwrap_or(false) + fn evaluate_selection( + &self, + batch: &arrow_array::RecordBatch, + selection: &BooleanArray, + ) -> DFResult { + let tmp_batch = arrow::compute::filter_record_batch(batch, selection)?; + + let tmp_result = self.evaluate(&tmp_batch)?; + + if batch.num_rows() == tmp_batch.num_rows() { + // All values from the `selection` filter are true. 
+ Ok(tmp_result) + } else if let ColumnarValue::Array(a) = tmp_result { + datafusion_physical_expr_common::utils::scatter(selection, a.as_ref()) + .map(ColumnarValue::Array) + } else { + Ok(tmp_result) + } } } @@ -310,7 +312,7 @@ fn array_regexp_match( array: &dyn ArrayAccessor, regex: &Regex, negated: bool, -) -> datafusion_common::Result { +) -> DFResult { let null_bit_buffer = array.nulls().map(|x| x.inner().sliced()); let mut buffer_builder = BooleanBufferBuilder::new(array.len()); @@ -359,7 +361,7 @@ pub fn scalar_regex_match( expr: Arc, pattern: Arc, input_schema: &Schema, -) -> datafusion_common::Result> { +) -> DFResult> { let valid_data_type = |data_type: &DataType| { if !matches!( data_type, From 930b7a8d72eab29bd313324b8d9525edbe633d3b Mon Sep 17 00:00:00 2001 From: zhuliquan Date: Sun, 10 Nov 2024 00:24:09 +0800 Subject: [PATCH 05/35] bench: improve scalar_regex_match --- .../benches/scalar_regex_match.rs | 106 ++++++++++-------- 1 file changed, 62 insertions(+), 44 deletions(-) diff --git a/datafusion/physical-expr/benches/scalar_regex_match.rs b/datafusion/physical-expr/benches/scalar_regex_match.rs index 680843c0cb56..c9cad78e5807 100644 --- a/datafusion/physical-expr/benches/scalar_regex_match.rs +++ b/datafusion/physical-expr/benches/scalar_regex_match.rs @@ -21,8 +21,10 @@ use arrow_array::{RecordBatch, StringArray}; use arrow_schema::{DataType, Field, Schema}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr_common::operator::Operator; -use datafusion_physical_expr::expressions::{binary, col, lit, scalar_regex_match}; -use hashbrown::HashMap; +use datafusion_physical_expr::{ + expressions::{binary, col, lit, scalar_regex_match}, + PhysicalExpr, +}; use rand::distributions::{Alphanumeric, DistString}; /// make a record batch with one column and n rows @@ -39,83 +41,99 @@ fn make_record_batch(rows: usize, string_length: usize, schema: Schema) -> Recor RecordBatch::try_new(Arc::new(schema), 
vec![Arc::new(array)]).unwrap() } -fn scalar_regex_match_benchmark(c: &mut Criterion) { +/// initialize benchmark data and pattern literals +fn init_benchmark() -> ( + Vec<(usize, RecordBatch)>, + Schema, + Arc, + Vec<(String, Arc)>, +) { // make common schema let column = "string"; let schema = Schema::new(vec![Field::new(column, DataType::Utf8, true)]); // meke test record batch - let test_batch = [ - (10, make_record_batch(10, 100, schema.clone())), - (100, make_record_batch(100, 100, schema.clone())), - (1000, make_record_batch(1000, 100, schema.clone())), - (2000, make_record_batch(2000, 100, schema.clone())), - ] - .iter() - .map(|(k, v)| (*k, v.clone())) - .collect::>(); + let batch_data = vec![ + // (10_usize, make_record_batch(10, 100, schema.clone())), + // (100_usize, make_record_batch(100, 100, schema.clone())), + // (1000_usize, make_record_batch(1000, 100, schema.clone())), + (2000_usize, make_record_batch(2000, 100, schema.clone())), + ]; // string column let string_col = col(column, &schema).unwrap(); // some pattern literal - let pattern_lit = [ - ("email".to_string(), lit(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$")), - ("url".to_string(), lit(r"^(https?|ftp)://[-a-zA-Z0-9+&@#/%?=~_|!:,.;]*[-a-zA-Z0-9+&@#/%=~_|]$")), - ("ip".to_string(), lit(r"^((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$")), - ("phone".to_string(), lit(r"^(\+\d{1,2}\s?)?\(?\d{3}\)?[\s.-]?\d{3}[\s.-]?\d{4}$")), - ("zip_code".to_string(), lit(r"^\d{5}(?:[-\s]\d{4})?$")), - ].iter() - .map(|(k, v)| (k.clone(), v.clone())) - .collect::>(); + let pattern_lit = vec![ + ( + format!("email"), + lit(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"), + ), + ( + format!("url"), + lit(r"^(https?|ftp)://[-a-zA-Z0-9+&@#/%?=~_|!:,.;]*[-a-zA-Z0-9+&@#/%=~_|]$"), + ), + ( + format!("ip"), + lit( + r"^((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$", + ), + ), + ( + format!("phone"), + 
lit(r"^(\+\d{1,2}\s?)?\(?\d{3}\)?[\s.-]?\d{3}[\s.-]?\d{4}$"), + ), + (format!("zip_code"), lit(r"^\d{5}(?:[-\s]\d{4})?$")), + ]; + (batch_data, schema, string_col, pattern_lit) +} +fn regex_match_benchmark(c: &mut Criterion) { + let (batch_data, schema, string_col, pattern_lit) = init_benchmark(); + // let record_batch_run_times = [10, 20, 50, 100]; + let record_batch_run_times = [10]; for (name, regexp_lit) in pattern_lit.iter() { - for (rows, batch) in test_batch.iter() { - for iter in [10, 20, 50, 100] { - // scalar regex match benchmarks - let bench_name = format!( - "scalar_regex_match_pattern_{}_rows_{}_iter_{}", - name, rows, iter - ); - c.bench_function(bench_name.as_str(), |b| { - let expr = scalar_regex_match( - false, - false, + for (rows, batch) in batch_data.iter() { + for run_time in record_batch_run_times { + let group_name = + format!("regex_{}_rows_{}_run_time_{}", name, rows, run_time); + let mut group = c.benchmark_group(group_name.as_str()); + // binary expr match benchmarks + group.bench_function("binary_expr_match", |b| { + let expr = binary( string_col.clone(), + Operator::RegexMatch, regexp_lit.clone(), &schema, ) .unwrap(); b.iter(|| { - for _ in 0..iter { + for _ in 0..run_time { expr.evaluate(black_box(batch)).unwrap(); } }); }); - - // binary regex match benchmarks - let bench_name = format!( - "binary_regex_match_pattern_{}_rows_{}_iter_{}", - name, rows, iter - ); - c.bench_function(bench_name.as_str(), |b| { - let expr = binary( + // scalar regex match benchmarks + group.bench_function("scalar_regex_match", |b| { + let expr = scalar_regex_match( + false, + false, string_col.clone(), - Operator::RegexMatch, regexp_lit.clone(), &schema, ) .unwrap(); b.iter(|| { - for _ in 0..iter { + for _ in 0..run_time { expr.evaluate(black_box(batch)).unwrap(); } }); }); + group.finish(); } } } } -criterion_group!(benches, scalar_regex_match_benchmark); +criterion_group!(benches, regex_match_benchmark); criterion_main!(benches); From 
5f2b555e783492bdf1904d343dcdd6f799441493 Mon Sep 17 00:00:00 2001 From: zhuliquan Date: Sun, 10 Nov 2024 00:59:04 +0800 Subject: [PATCH 06/35] minor: update cargo.lock --- datafusion-cli/Cargo.lock | 1 + .../physical-expr/benches/scalar_regex_match.rs | 11 ++++++----- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index c871b2fdda08..2b23b1983772 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1562,6 +1562,7 @@ dependencies = [ "log", "paste", "petgraph", + "regex", ] [[package]] diff --git a/datafusion/physical-expr/benches/scalar_regex_match.rs b/datafusion/physical-expr/benches/scalar_regex_match.rs index c9cad78e5807..dd958c4dd63a 100644 --- a/datafusion/physical-expr/benches/scalar_regex_match.rs +++ b/datafusion/physical-expr/benches/scalar_regex_match.rs @@ -42,6 +42,7 @@ fn make_record_batch(rows: usize, string_length: usize, schema: Schema) -> Recor } /// initialize benchmark data and pattern literals +#[allow(clippy::type_complexity)] fn init_benchmark() -> ( Vec<(usize, RecordBatch)>, Schema, @@ -66,24 +67,24 @@ fn init_benchmark() -> ( // some pattern literal let pattern_lit = vec![ ( - format!("email"), + "email".to_string(), lit(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"), ), ( - format!("url"), + "url".to_string(), lit(r"^(https?|ftp)://[-a-zA-Z0-9+&@#/%?=~_|!:,.;]*[-a-zA-Z0-9+&@#/%=~_|]$"), ), ( - format!("ip"), + "ip".to_string(), lit( r"^((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$", ), ), ( - format!("phone"), + "phone".to_string(), lit(r"^(\+\d{1,2}\s?)?\(?\d{3}\)?[\s.-]?\d{3}[\s.-]?\d{4}$"), ), - (format!("zip_code"), lit(r"^\d{5}(?:[-\s]\d{4})?$")), + ("zip_code".to_string(), lit(r"^\d{5}(?:[-\s]\d{4})?$")), ]; (batch_data, schema, string_col, pattern_lit) } From f309341b64b307aac874d2b236a0aa659d4d1efd Mon Sep 17 00:00:00 2001 From: zhuliquan Date: Tue, 3 Dec 2024 23:15:01 +0800 Subject: [PATCH 
07/35] fix: fix wrong merge conflict --- datafusion/proto/proto/datafusion.proto | 2 +- datafusion/proto/src/physical_plan/to_proto.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index d1f81a115d60..eb5e5fe3f324 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -830,7 +830,7 @@ message PhysicalExprNode { PhysicalExtensionExprNode extension = 19; UnknownColumn unknown_column = 20; - + PhysicalScalarRegexMatchExprNode scalar_regex_match_expr = 21; } } diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 3f91ef93d492..43af61a51716 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -23,7 +23,7 @@ use datafusion::physical_expr::window::{SlidingAggregateWindowExpr, StandardWind use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ BinaryExpr, CaseExpr, CastExpr, Column, InListExpr, IsNotNullExpr, IsNullExpr, - Literal, NegativeExpr, NotExpr, NthValue, ScalarRegexMatchExpr, TryCastExpr, UnKnownColumn, + Literal, NegativeExpr, NotExpr, ScalarRegexMatchExpr, TryCastExpr, UnKnownColumn, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{PlainAggregateWindowExpr, WindowUDFExpr}; From a0ba73a6c57d79b49d99a909d808c56d4fc93eb2 Mon Sep 17 00:00:00 2001 From: zhuliquan Date: Sat, 7 Dec 2024 01:54:03 +0800 Subject: [PATCH 08/35] bench: init expr in scalar_regex_match bench iter --- .../benches/scalar_regex_match.rs | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/datafusion/physical-expr/benches/scalar_regex_match.rs b/datafusion/physical-expr/benches/scalar_regex_match.rs index dd958c4dd63a..bed346877560 100644 --- 
a/datafusion/physical-expr/benches/scalar_regex_match.rs +++ b/datafusion/physical-expr/benches/scalar_regex_match.rs @@ -101,14 +101,14 @@ fn regex_match_benchmark(c: &mut Criterion) { let mut group = c.benchmark_group(group_name.as_str()); // binary expr match benchmarks group.bench_function("binary_expr_match", |b| { - let expr = binary( - string_col.clone(), - Operator::RegexMatch, - regexp_lit.clone(), - &schema, - ) - .unwrap(); b.iter(|| { + let expr = binary( + string_col.clone(), + Operator::RegexMatch, + regexp_lit.clone(), + &schema, + ) + .unwrap(); for _ in 0..run_time { expr.evaluate(black_box(batch)).unwrap(); } @@ -116,15 +116,15 @@ fn regex_match_benchmark(c: &mut Criterion) { }); // scalar regex match benchmarks group.bench_function("scalar_regex_match", |b| { - let expr = scalar_regex_match( - false, - false, - string_col.clone(), - regexp_lit.clone(), - &schema, - ) - .unwrap(); b.iter(|| { + let expr = scalar_regex_match( + false, + false, + string_col.clone(), + regexp_lit.clone(), + &schema, + ) + .unwrap(); for _ in 0..run_time { expr.evaluate(black_box(batch)).unwrap(); } From 7014ac34a63ebda204a0ae708993864b28c5c6bf Mon Sep 17 00:00:00 2001 From: zhuliquan Date: Sat, 7 Dec 2024 22:32:43 +0800 Subject: [PATCH 09/35] bench: diff batch run over in scalar_regex_match --- .../benches/scalar_regex_match.rs | 123 ++++++++++-------- 1 file changed, 69 insertions(+), 54 deletions(-) diff --git a/datafusion/physical-expr/benches/scalar_regex_match.rs b/datafusion/physical-expr/benches/scalar_regex_match.rs index bed346877560..139c5049c87c 100644 --- a/datafusion/physical-expr/benches/scalar_regex_match.rs +++ b/datafusion/physical-expr/benches/scalar_regex_match.rs @@ -19,46 +19,63 @@ use std::sync::Arc; use arrow_array::{RecordBatch, StringArray}; use arrow_schema::{DataType, Field, Schema}; -use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use criterion::{criterion_group, criterion_main, Criterion}; use 
datafusion_expr_common::operator::Operator; use datafusion_physical_expr::{ expressions::{binary, col, lit, scalar_regex_match}, PhysicalExpr, }; -use rand::distributions::{Alphanumeric, DistString}; +use rand::{ + distributions::{Alphanumeric, DistString}, + rngs::StdRng, + SeedableRng, +}; /// make a record batch with one column and n rows /// this record batch is single string column is used for /// scalar regex match benchmarks -fn make_record_batch(rows: usize, string_length: usize, schema: Schema) -> RecordBatch { - let mut rng = rand::thread_rng(); - let mut array = Vec::with_capacity(rows); - for _ in 0..rows { - let data_line = Alphanumeric.sample_string(&mut rng, string_length); - array.push(Some(data_line)); +fn make_record_batch( + batch_iter: usize, + batch_size: usize, + string_len: usize, + schema: &Schema, +) -> Vec { + let mut rng = StdRng::from_seed([123; 32]); + let mut batches = vec![]; + for _ in 0..batch_iter { + let array = (0..batch_size) + .map(|_| Some(Alphanumeric.sample_string(&mut rng, string_len))) + .collect::>(); + let array = StringArray::from(array); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)]) + .unwrap(); + batches.push(batch); } - let array = StringArray::from(array); - RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).unwrap() + batches } /// initialize benchmark data and pattern literals #[allow(clippy::type_complexity)] fn init_benchmark() -> ( - Vec<(usize, RecordBatch)>, + Vec<(usize, usize, Vec)>, Schema, Arc, Vec<(String, Arc)>, ) { // make common schema - let column = "string"; + let column = "s"; let schema = Schema::new(vec![Field::new(column, DataType::Utf8, true)]); // meke test record batch let batch_data = vec![ - // (10_usize, make_record_batch(10, 100, schema.clone())), - // (100_usize, make_record_batch(100, 100, schema.clone())), - // (1000_usize, make_record_batch(1000, 100, schema.clone())), - (2000_usize, make_record_batch(2000, 100, schema.clone())), + // 
(20, 10_usize, make_record_batch(20, 10, 100, schema.clone())), + // (20, 100_usize, make_record_batch(20, 100, 100, schema.clone())), + // (20, 1000_usize, make_record_batch(20, 1000, 100, schema.clone())), + ( + 128_usize, + 4096_usize, + make_record_batch(128, 4096, 100, &schema), + ), ]; // string column @@ -91,47 +108,45 @@ fn init_benchmark() -> ( fn regex_match_benchmark(c: &mut Criterion) { let (batch_data, schema, string_col, pattern_lit) = init_benchmark(); - // let record_batch_run_times = [10, 20, 50, 100]; - let record_batch_run_times = [10]; for (name, regexp_lit) in pattern_lit.iter() { - for (rows, batch) in batch_data.iter() { - for run_time in record_batch_run_times { - let group_name = - format!("regex_{}_rows_{}_run_time_{}", name, rows, run_time); - let mut group = c.benchmark_group(group_name.as_str()); - // binary expr match benchmarks - group.bench_function("binary_expr_match", |b| { - b.iter(|| { - let expr = binary( - string_col.clone(), - Operator::RegexMatch, - regexp_lit.clone(), - &schema, - ) - .unwrap(); - for _ in 0..run_time { - expr.evaluate(black_box(batch)).unwrap(); - } - }); + for (batch_iter, batch_size, batches) in batch_data.iter() { + let group_name = format!( + "regex_{}_batch_iter_{}_batch_size_{}", + name, batch_iter, batch_size + ); + let mut group = c.benchmark_group(group_name.as_str()); + // binary expr match benchmarks + group.bench_function("binary_expr_match", |b| { + b.iter(|| { + let expr = binary( + string_col.clone(), + Operator::RegexMatch, + regexp_lit.clone(), + &schema, + ) + .unwrap(); + for batch in batches.iter() { + expr.evaluate(batch).unwrap(); + } }); - // scalar regex match benchmarks - group.bench_function("scalar_regex_match", |b| { - b.iter(|| { - let expr = scalar_regex_match( - false, - false, - string_col.clone(), - regexp_lit.clone(), - &schema, - ) - .unwrap(); - for _ in 0..run_time { - expr.evaluate(black_box(batch)).unwrap(); - } - }); + }); + // scalar regex match benchmarks + 
group.bench_function("scalar_regex_match", |b| { + b.iter(|| { + let expr = scalar_regex_match( + false, + false, + string_col.clone(), + regexp_lit.clone(), + &schema, + ) + .unwrap(); + for batch in batches.iter() { + expr.evaluate(batch).unwrap(); + } }); - group.finish(); - } + }); + group.finish(); } } } From ef66c5623e4c2395e435b71bc8157bb4993ca6d6 Mon Sep 17 00:00:00 2001 From: zhuliquan Date: Mon, 9 Dec 2024 01:10:32 +0800 Subject: [PATCH 10/35] improve: improve performance of scalar_regex_match --- .../benches/scalar_regex_match.rs | 22 ++++++- .../src/expressions/scalar_regex_match.rs | 58 +++++-------------- 2 files changed, 33 insertions(+), 47 deletions(-) diff --git a/datafusion/physical-expr/benches/scalar_regex_match.rs b/datafusion/physical-expr/benches/scalar_regex_match.rs index 139c5049c87c..9c6826800600 100644 --- a/datafusion/physical-expr/benches/scalar_regex_match.rs +++ b/datafusion/physical-expr/benches/scalar_regex_match.rs @@ -38,14 +38,18 @@ fn make_record_batch( batch_iter: usize, batch_size: usize, string_len: usize, + matched_str: &[&str], schema: &Schema, ) -> Vec { - let mut rng = StdRng::from_seed([123; 32]); + let mut rng = StdRng::seed_from_u64(12345); let mut batches = vec![]; for _ in 0..batch_iter { - let array = (0..batch_size) + let mut array = (0..batch_size) .map(|_| Some(Alphanumeric.sample_string(&mut rng, string_len))) .collect::>(); + for v in matched_str { + array.push(Some(v.to_string())); + } let array = StringArray::from(array); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)]) .unwrap(); @@ -74,7 +78,19 @@ fn init_benchmark() -> ( ( 128_usize, 4096_usize, - make_record_batch(128, 4096, 100, &schema), + make_record_batch( + 128, + 4096, + 100, + &[ + "example@email.com", + "http://example.com", + "123.4.5.6", + "1236787788", + "55555", + ], + &schema, + ), ), ]; diff --git a/datafusion/physical-expr/src/expressions/scalar_regex_match.rs 
b/datafusion/physical-expr/src/expressions/scalar_regex_match.rs index cc446f3328d5..b4e1e92306cb 100644 --- a/datafusion/physical-expr/src/expressions/scalar_regex_match.rs +++ b/datafusion/physical-expr/src/expressions/scalar_regex_match.rs @@ -16,11 +16,10 @@ // under the License. use super::Literal; -use arrow::array::ArrayData; use arrow_array::{ Array, ArrayAccessor, BooleanArray, LargeStringArray, StringArray, StringViewArray, }; -use arrow_buffer::BooleanBufferBuilder; +use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder}; use arrow_schema::{DataType, Schema}; use datafusion_common::{Result as DFResult, ScalarValue}; use datafusion_expr::ColumnarValue; @@ -144,6 +143,8 @@ impl ScalarRegexMatchExpr { &self, array: &Arc, ) -> datafusion_common::Result { + /// downcast_string_array downcast a [`ArrayRef`] to specific array type + /// example: [`StringArray`], [`LargeStringArray`], [`StringViewArray`] macro_rules! downcast_string_array { ($ARRAY:expr, $ARRAY_TYPE:ident, $ERR_MSG:expr) => { &($ARRAY @@ -157,7 +158,7 @@ impl ScalarRegexMatchExpr { Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))) }, DataType::Utf8 => array_regexp_match( - downcast_string_array!(array, StringArray, "Failed to downcast StringArray"), + downcast_string_array!(array, StringArray, "Failed to downcast StringArray"), self.compiled.as_ref().unwrap(), self.negated, ), @@ -285,26 +286,6 @@ impl PhysicalExpr for ScalarRegexMatchExpr { Arc::clone(&children[1]), ))) } - - fn evaluate_selection( - &self, - batch: &arrow_array::RecordBatch, - selection: &BooleanArray, - ) -> DFResult { - let tmp_batch = arrow::compute::filter_record_batch(batch, selection)?; - - let tmp_result = self.evaluate(&tmp_batch)?; - - if batch.num_rows() == tmp_batch.num_rows() { - // All values from the `selection` filter are true. 
- Ok(tmp_result) - } else if let ColumnarValue::Array(a) = tmp_result { - datafusion_physical_expr_common::utils::scatter(selection, a.as_ref()) - .map(ColumnarValue::Array) - } else { - Ok(tmp_result) - } - } } /// It is used for scalar regexp matching and copy from arrow-rs @@ -313,31 +294,20 @@ fn array_regexp_match( regex: &Regex, negated: bool, ) -> DFResult { - let null_bit_buffer = array.nulls().map(|x| x.inner().sliced()); - let mut buffer_builder = BooleanBufferBuilder::new(array.len()); - - if regex.as_str().is_empty() { - buffer_builder.append_n(array.len(), true); + let null_buffer = array.logical_nulls(); + let bool_buffer = if regex.as_str().is_empty() { + BooleanBuffer::new_set(array.len()) } else { + let mut bool_buffer_builder = BooleanBufferBuilder::new(array.len()); + bool_buffer_builder.advance(array.len()); for i in 0..array.len() { - let value = array.value(i); - buffer_builder.append(regex.is_match(value)); + let value = unsafe { array.value_unchecked(i) }; + bool_buffer_builder.set_bit(i, regex.is_match(value)); } - } - - let buffer = buffer_builder.into(); - let bool_array = BooleanArray::from(unsafe { - ArrayData::new_unchecked( - DataType::Boolean, - array.len(), - None, - null_bit_buffer, - 0, - vec![buffer], - vec![], - ) - }); + bool_buffer_builder.finish() + }; + let bool_array = BooleanArray::new(bool_buffer, null_buffer); let bool_array = if negated { arrow::compute::kernels::boolean::not(&bool_array) } else { From 13adab59f8318b466418c64303b700e662279887 Mon Sep 17 00:00:00 2001 From: Oleks V Date: Fri, 6 Dec 2024 12:25:44 -0800 Subject: [PATCH 11/35] Minor: Comment temporary function for documentation migration (#13669) * Minor: Comment temporary function for documentation migration * Minor: Comment temporary function for documentation migration --- datafusion/core/src/bin/print_functions_docs.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/datafusion/core/src/bin/print_functions_docs.rs 
b/datafusion/core/src/bin/print_functions_docs.rs index b58f6e47d333..8b453d5e9698 100644 --- a/datafusion/core/src/bin/print_functions_docs.rs +++ b/datafusion/core/src/bin/print_functions_docs.rs @@ -88,6 +88,7 @@ fn print_window_docs() -> Result { // the migration of UDF documentation generation from code based // to attribute based // To be removed +#[allow(dead_code)] fn save_doc_code_text(documentation: &Documentation, name: &str) { let attr_text = documentation.to_doc_attribute(); @@ -182,7 +183,7 @@ fn print_docs( }; // Temporary for doc gen migration, see `save_doc_code_text` comments - save_doc_code_text(documentation, &name); + // save_doc_code_text(documentation, &name); // first, the name, description and syntax example let _ = write!( From 7cfaf1e35cb52b62506577a338effd51d260f08b Mon Sep 17 00:00:00 2001 From: Oleks V Date: Fri, 6 Dec 2024 13:13:05 -0800 Subject: [PATCH 12/35] Minor: Rephrase MSRV policy to be more explanatory (#13668) * Minor: Rephrase MSRV policy to be more explanatory Co-authored-by: Andrew Lamb * MSRV policy update --------- Co-authored-by: Andrew Lamb --- README.md | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 553097552418..2e4f2c347fe5 100644 --- a/README.md +++ b/README.md @@ -126,14 +126,17 @@ Optional features: ## Rust Version Compatibility Policy -DataFusion's Minimum Required Stable Rust Version (MSRV) policy is to support stable [4 latest -Rust versions](https://releases.rs) OR the stable minor Rust version as of 4 months, whichever is lower. +The Rust toolchain releases are tracked at [Rust Versions](https://releases.rs) and follow +[semantic versioning](https://semver.org/). A Rust toolchain release can be identified +by a version string like `1.80.0`, or more generally `major.minor.patch`. + +DataFusion supports the last 4 stable Rust minor versions released and any such versions released within the last 4 months. 
For example, given the releases `1.78.0`, `1.79.0`, `1.80.0`, `1.80.1` and `1.81.0` DataFusion will support 1.78.0, which is 3 minor versions prior to the most minor recent `1.81`. -If a hotfix is released for the minimum supported Rust version (MSRV), the MSRV will be the minor version with all hotfixes, even if it surpasses the four-month window. +Note: If a Rust hotfix is released for the current MSRV, the MSRV will be updated to the specific minor version that includes all applicable hotfixes preceding other policies. -We enforce this policy using a [MSRV CI Check](https://github.com/search?q=repo%3Aapache%2Fdatafusion+rust-version+language%3ATOML+path%3A%2F%5ECargo.toml%2F&type=code) +DataFusion enforces MSRV policy using a [MSRV CI Check](https://github.com/search?q=repo%3Aapache%2Fdatafusion+rust-version+language%3ATOML+path%3A%2F%5ECargo.toml%2F&type=code) ## DataFusion API evolution policy From 98b7488a3c25e38bef150bae9aac46ef8521f805 Mon Sep 17 00:00:00 2001 From: Eduard Karacharov Date: Sat, 7 Dec 2024 16:00:07 +0200 Subject: [PATCH 13/35] fix: repartitioned reads of CSV with custom line terminator (#13677) --- .../core/src/datasource/physical_plan/csv.rs | 4 ++- .../core/src/datasource/physical_plan/json.rs | 2 +- .../core/src/datasource/physical_plan/mod.rs | 11 +++--- .../sqllogictest/test_files/csv_files.slt | 36 ++++++++++++++----- 4 files changed, 38 insertions(+), 15 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index 0c41f69c7691..c54c663dca7d 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -612,11 +612,13 @@ impl FileOpener for CsvOpener { } let store = Arc::clone(&self.config.object_store); + let terminator = self.config.terminator; Ok(Box::pin(async move { // Current partition contains bytes [start_byte, end_byte) (might contain incomplete lines at boundaries) - let calculated_range 
= calculate_range(&file_meta, &store).await?; + let calculated_range = + calculate_range(&file_meta, &store, terminator).await?; let range = match calculated_range { RangeCalculation::Range(None) => None, diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs index c07e8ca74543..5c70968fbb42 100644 --- a/datafusion/core/src/datasource/physical_plan/json.rs +++ b/datafusion/core/src/datasource/physical_plan/json.rs @@ -273,7 +273,7 @@ impl FileOpener for JsonOpener { let file_compression_type = self.file_compression_type.to_owned(); Ok(Box::pin(async move { - let calculated_range = calculate_range(&file_meta, &store).await?; + let calculated_range = calculate_range(&file_meta, &store, None).await?; let range = match calculated_range { RangeCalculation::Range(None) => None, diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index 449b7bb43519..3146d124d9f1 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -426,9 +426,11 @@ enum RangeCalculation { async fn calculate_range( file_meta: &FileMeta, store: &Arc, + terminator: Option, ) -> Result { let location = file_meta.location(); let file_size = file_meta.object_meta.size; + let newline = terminator.unwrap_or(b'\n'); match file_meta.range { None => Ok(RangeCalculation::Range(None)), @@ -436,13 +438,13 @@ async fn calculate_range( let (start, end) = (start as usize, end as usize); let start_delta = if start != 0 { - find_first_newline(store, location, start - 1, file_size).await? + find_first_newline(store, location, start - 1, file_size, newline).await? } else { 0 }; let end_delta = if end != file_size { - find_first_newline(store, location, end - 1, file_size).await? + find_first_newline(store, location, end - 1, file_size, newline).await? 
} else { 0 }; @@ -462,7 +464,7 @@ async fn calculate_range( /// within an object, such as a file, in an object store. /// /// This function scans the contents of the object starting from the specified `start` position -/// up to the `end` position, looking for the first occurrence of a newline (`'\n'`) character. +/// up to the `end` position, looking for the first occurrence of a newline character. /// It returns the position of the first newline relative to the start of the range. /// /// Returns a `Result` wrapping a `usize` that represents the position of the first newline character found within the specified range. If no newline is found, it returns the length of the scanned data, effectively indicating the end of the range. @@ -474,6 +476,7 @@ async fn find_first_newline( location: &Path, start: usize, end: usize, + newline: u8, ) -> Result { let options = GetOptions { range: Some(GetRange::Bounded(start..end)), @@ -486,7 +489,7 @@ async fn find_first_newline( let mut index = 0; while let Some(chunk) = result_stream.next().await.transpose()? 
{ - if let Some(position) = chunk.iter().position(|&byte| byte == b'\n') { + if let Some(position) = chunk.iter().position(|&byte| byte == newline) { return Ok(index + position); } diff --git a/datafusion/sqllogictest/test_files/csv_files.slt b/datafusion/sqllogictest/test_files/csv_files.slt index 01d0f4ac39bd..5906c6a19bb8 100644 --- a/datafusion/sqllogictest/test_files/csv_files.slt +++ b/datafusion/sqllogictest/test_files/csv_files.slt @@ -350,15 +350,33 @@ col2 TEXT LOCATION '../core/tests/data/cr_terminator.csv' OPTIONS ('format.terminator' E'\r', 'format.has_header' 'true'); -# TODO: It should be passed but got the error: External error: query failed: DataFusion error: Object Store error: Generic LocalFileSystem error: Requested range was invalid -# See the issue: https://github.com/apache/datafusion/issues/12328 -# query TT -# select * from stored_table_with_cr_terminator; -# ---- -# id0 value0 -# id1 value1 -# id2 value2 -# id3 value3 +# Check single-thread reading of CSV with custom line terminator +statement ok +SET datafusion.optimizer.repartition_file_min_size = 10485760; + +query TT +select * from stored_table_with_cr_terminator; +---- +id0 value0 +id1 value1 +id2 value2 +id3 value3 + +# Check repartitioned reading of CSV with custom line terminator +statement ok +SET datafusion.optimizer.repartition_file_min_size = 1; + +query TT +select * from stored_table_with_cr_terminator order by col1; +---- +id0 value0 +id1 value1 +id2 value2 +id3 value3 + +# Reset repartition_file_min_size to default value +statement ok +SET datafusion.optimizer.repartition_file_min_size = 10485760; statement ok drop table stored_table_with_cr_terminator; From 14dcf209c0634fac8dcf7b202d88e294472802d6 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Sat, 7 Dec 2024 22:13:56 +0100 Subject: [PATCH 14/35] chore: macros crate cleanup (#13685) * Remove unused dependencies from macros crate * rename macro lib to user_doc --- datafusion-cli/Cargo.lock | 1 - 
datafusion/macros/Cargo.toml | 4 ++-- datafusion/macros/src/{lib.rs => user_doc.rs} | 0 3 files changed, 2 insertions(+), 3 deletions(-) rename datafusion/macros/src/{lib.rs => user_doc.rs} (100%) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 2b23b1983772..4ef828a5a8f7 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1517,7 +1517,6 @@ dependencies = [ name = "datafusion-macros" version = "43.0.0" dependencies = [ - "datafusion-doc", "proc-macro2", "quote", "syn", diff --git a/datafusion/macros/Cargo.toml b/datafusion/macros/Cargo.toml index c5ac9d08dffa..07aa07fa927a 100644 --- a/datafusion/macros/Cargo.toml +++ b/datafusion/macros/Cargo.toml @@ -32,11 +32,11 @@ workspace = true [lib] name = "datafusion_macros" -path = "src/lib.rs" +# lib.rs to be re-added in the future +path = "src/user_doc.rs" proc-macro = true [dependencies] -datafusion-doc = { workspace = true } proc-macro2 = "1.0" quote = "1.0.37" syn = { version = "2.0.79", features = ["full"] } diff --git a/datafusion/macros/src/lib.rs b/datafusion/macros/src/user_doc.rs similarity index 100% rename from datafusion/macros/src/lib.rs rename to datafusion/macros/src/user_doc.rs From f6cafba1654ae4a4f3ede7c360d7f203f1ca1f54 Mon Sep 17 00:00:00 2001 From: Jiashen Cao Date: Sun, 8 Dec 2024 02:48:57 -0500 Subject: [PATCH 15/35] Refactor regexplike signature (#13394) * update * update * update * clean up errors * fix flags types * fix failed example --- datafusion-examples/examples/regexp.rs | 2 +- datafusion/functions/src/regex/regexplike.rs | 50 +++++++++++-------- .../test_files/string/string_view.slt | 2 +- 3 files changed, 31 insertions(+), 23 deletions(-) diff --git a/datafusion-examples/examples/regexp.rs b/datafusion-examples/examples/regexp.rs index 02e74bae22af..5419efd2faea 100644 --- a/datafusion-examples/examples/regexp.rs +++ b/datafusion-examples/examples/regexp.rs @@ -148,7 +148,7 @@ async fn main() -> Result<()> { // invalid flags will result in 
an error let result = ctx - .sql(r"select regexp_like('\b4(?!000)\d\d\d\b', 4010, 'g')") + .sql(r"select regexp_like('\b4(?!000)\d\d\d\b', '4010', 'g')") .await? .collect() .await; diff --git a/datafusion/functions/src/regex/regexplike.rs b/datafusion/functions/src/regex/regexplike.rs index 49e57776c7b8..1c826b12ef8f 100644 --- a/datafusion/functions/src/regex/regexplike.rs +++ b/datafusion/functions/src/regex/regexplike.rs @@ -81,26 +81,7 @@ impl RegexpLikeFunc { pub fn new() -> Self { Self { signature: Signature::one_of( - vec![ - TypeSignature::Exact(vec![Utf8View, Utf8]), - TypeSignature::Exact(vec![Utf8View, Utf8View]), - TypeSignature::Exact(vec![Utf8View, LargeUtf8]), - TypeSignature::Exact(vec![Utf8, Utf8]), - TypeSignature::Exact(vec![Utf8, Utf8View]), - TypeSignature::Exact(vec![Utf8, LargeUtf8]), - TypeSignature::Exact(vec![LargeUtf8, Utf8]), - TypeSignature::Exact(vec![LargeUtf8, Utf8View]), - TypeSignature::Exact(vec![LargeUtf8, LargeUtf8]), - TypeSignature::Exact(vec![Utf8View, Utf8, Utf8]), - TypeSignature::Exact(vec![Utf8View, Utf8View, Utf8]), - TypeSignature::Exact(vec![Utf8View, LargeUtf8, Utf8]), - TypeSignature::Exact(vec![Utf8, Utf8, Utf8]), - TypeSignature::Exact(vec![Utf8, Utf8View, Utf8]), - TypeSignature::Exact(vec![Utf8, LargeUtf8, Utf8]), - TypeSignature::Exact(vec![LargeUtf8, Utf8, Utf8]), - TypeSignature::Exact(vec![LargeUtf8, Utf8View, Utf8]), - TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, Utf8]), - ], + vec![TypeSignature::String(2), TypeSignature::String(3)], Volatility::Immutable, ), } @@ -211,7 +192,34 @@ pub fn regexp_like(args: &[ArrayRef]) -> Result { match args.len() { 2 => handle_regexp_like(&args[0], &args[1], None), 3 => { - let flags = args[2].as_string::(); + let flags = match args[2].data_type() { + Utf8 => args[2].as_string::(), + LargeUtf8 => { + let large_string_array = args[2].as_string::(); + let string_vec: Vec> = (0..large_string_array.len()).map(|i| { + if large_string_array.is_null(i) { + None + } else { + 
Some(large_string_array.value(i)) + } + }) + .collect(); + + &GenericStringArray::::from(string_vec) + }, + _ => { + let string_view_array = args[2].as_string_view(); + let string_vec: Vec> = (0..string_view_array.len()).map(|i| { + if string_view_array.is_null(i) { + None + } else { + Some(string_view_array.value(i).to_string()) + } + }) + .collect(); + &GenericStringArray::::from(string_vec) + }, + }; if flags.iter().any(|s| s == Some("g")) { return plan_err!("regexp_like() does not support the \"global\" option"); diff --git a/datafusion/sqllogictest/test_files/string/string_view.slt b/datafusion/sqllogictest/test_files/string/string_view.slt index ebabaf7655ff..c37dd1ed3b4f 100644 --- a/datafusion/sqllogictest/test_files/string/string_view.slt +++ b/datafusion/sqllogictest/test_files/string/string_view.slt @@ -731,7 +731,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: regexp_like(test.column1_utf8view, Utf8("^https?://(?:www\.)?([^/]+)/.*$")) AS k +01)Projection: regexp_like(test.column1_utf8view, Utf8View("^https?://(?:www\.)?([^/]+)/.*$")) AS k 02)--TableScan: test projection=[column1_utf8view] ## Ensure no casts for REGEXP_MATCH From cebf94fa8a47446994ab61d4ea580d414569119d Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 8 Dec 2024 08:28:05 -0500 Subject: [PATCH 16/35] Performance: enable array allocation reuse (`ScalarFunctionArgs` gets owned `ColumnReference`) (#13637) * Improve documentation * Pass owned args to ScalarFunctionArgs * Update advanced_udf with example of reusing arrays * clarify rationale for cloning * clarify comments * fix expected output --- datafusion-examples/examples/advanced_udf.rs | 126 ++++++++++++++---- datafusion/expr/src/udf.rs | 14 +- .../functions/src/datetime/to_local_time.rs | 2 +- datafusion/functions/src/string/ascii.rs | 6 +- datafusion/functions/src/string/btrim.rs | 24 ++-- datafusion/functions/src/string/concat.rs | 12 +- datafusion/functions/src/string/concat_ws.rs | 8 +- 
datafusion/functions/src/string/ends_with.rs | 8 +- datafusion/functions/src/string/initcap.rs | 16 +-- datafusion/functions/src/string/ltrim.rs | 22 +-- .../functions/src/string/octet_length.rs | 20 +-- datafusion/functions/src/string/repeat.rs | 12 +- datafusion/functions/src/string/replace.rs | 6 +- datafusion/functions/src/string/rtrim.rs | 22 +-- datafusion/functions/src/string/split_part.rs | 8 +- .../functions/src/string/starts_with.rs | 2 +- .../functions/src/unicode/character_length.rs | 6 +- datafusion/functions/src/unicode/left.rs | 18 +-- datafusion/functions/src/unicode/lpad.rs | 24 ++-- datafusion/functions/src/unicode/reverse.rs | 6 +- datafusion/functions/src/unicode/right.rs | 18 +-- datafusion/functions/src/unicode/rpad.rs | 28 ++-- datafusion/functions/src/unicode/strpos.rs | 2 +- datafusion/functions/src/unicode/substr.rs | 56 ++++---- .../functions/src/unicode/substrindex.rs | 14 +- datafusion/functions/src/unicode/translate.rs | 12 +- .../physical-expr/src/scalar_function.rs | 8 +- 27 files changed, 288 insertions(+), 212 deletions(-) diff --git a/datafusion-examples/examples/advanced_udf.rs b/datafusion-examples/examples/advanced_udf.rs index aee3be6c9285..ae35cff6facf 100644 --- a/datafusion-examples/examples/advanced_udf.rs +++ b/datafusion-examples/examples/advanced_udf.rs @@ -27,9 +27,11 @@ use arrow::record_batch::RecordBatch; use datafusion::error::Result; use datafusion::logical_expr::Volatility; use datafusion::prelude::*; -use datafusion_common::{internal_err, ScalarValue}; +use datafusion_common::{exec_err, internal_err, ScalarValue}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; -use datafusion_expr::{ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, +}; /// This example shows how to use the full ScalarUDFImpl API to implement a user /// defined function. 
As in the `simple_udf.rs` example, this struct implements @@ -83,23 +85,27 @@ impl ScalarUDFImpl for PowUdf { Ok(DataType::Float64) } - /// This is the function that actually calculates the results. + /// This function actually calculates the results of the scalar function. + /// + /// This is the same way that functions provided with DataFusion are invoked, + /// which permits important special cases: /// - /// This is the same way that functions built into DataFusion are invoked, - /// which permits important special cases when one or both of the arguments - /// are single values (constants). For example `pow(a, 2)` + ///1. When one or both of the arguments are single values (constants). + /// For example `pow(a, 2)` + /// 2. When the input arrays can be reused (avoid allocating a new output array) /// /// However, it also means the implementation is more complex than when /// using `create_udf`. - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + // The other fields of the `args` struct are used for more specialized + // uses, and are not needed in this example + let ScalarFunctionArgs { mut args, .. } = args; // DataFusion has arranged for the correct inputs to be passed to this // function, but we check again to make sure assert_eq!(args.len(), 2); - let (base, exp) = (&args[0], &args[1]); + // take ownership of arguments by popping in reverse order + let exp = args.pop().unwrap(); + let base = args.pop().unwrap(); assert_eq!(base.data_type(), DataType::Float64); assert_eq!(exp.data_type(), DataType::Float64); @@ -118,7 +124,7 @@ impl ScalarUDFImpl for PowUdf { ) => { // compute the output. Note DataFusion treats `None` as NULL. 
let res = match (base, exp) { - (Some(base), Some(exp)) => Some(base.powf(*exp)), + (Some(base), Some(exp)) => Some(base.powf(exp)), // one or both arguments were NULL _ => None, }; @@ -140,31 +146,33 @@ impl ScalarUDFImpl for PowUdf { // kernel creates very fast "vectorized" code and // handles things like null values for us. let res: Float64Array = - compute::unary(base_array, |base| base.powf(*exp)); + compute::unary(base_array, |base| base.powf(exp)); Arc::new(res) } }; Ok(ColumnarValue::Array(result_array)) } - // special case if the base is a constant (note this code is quite - // similar to the previous case, so we omit comments) + // special case if the base is a constant. + // + // Note this case is very similar to the previous case, so we could + // use the same pattern. However, for this case we demonstrate an + // even more advanced pattern to potentially avoid allocating a new array ( ColumnarValue::Scalar(ScalarValue::Float64(base)), ColumnarValue::Array(exp_array), ) => { let res = match base { None => new_null_array(exp_array.data_type(), exp_array.len()), - Some(base) => { - let exp_array = exp_array.as_primitive::(); - let res: Float64Array = - compute::unary(exp_array, |exp| base.powf(exp)); - Arc::new(res) - } + Some(base) => maybe_pow_in_place(base, exp_array)?, }; Ok(ColumnarValue::Array(res)) } - // Both arguments are arrays so we have to perform the calculation for every row + // Both arguments are arrays so we have to perform the calculation + // for every row + // + // Note this could also be done in place using `binary_mut` as + // is done in `maybe_pow_in_place` but here we use binary for simplicity (ColumnarValue::Array(base_array), ColumnarValue::Array(exp_array)) => { let res: Float64Array = compute::binary( base_array.as_primitive::(), @@ -191,6 +199,52 @@ impl ScalarUDFImpl for PowUdf { } } +/// Evaluate `base ^ exp` *without* allocating a new array, if possible +fn maybe_pow_in_place(base: f64, exp_array: ArrayRef) -> Result { + // 
Calling `unary` creates a new array for the results. Avoiding + // allocations is a common optimization in performance critical code. + // arrow-rs allows this optimization via the `unary_mut` + // and `binary_mut` kernels in certain cases + // + // These kernels can only be used if there are no other references to + // the arrays (exp_array has to be the last remaining reference). + let owned_array = exp_array + // as in the previous example, we first downcast to &Float64Array + .as_primitive::() + // non-obviously, we call clone here to get an owned `Float64Array`. + // (Calling clone() is relatively inexpensive as it increments + // some ref counts but doesn't clone the data) + // + // Once we have the owned Float64Array we can drop the original + // exp_array (untyped) reference + .clone(); + + // We *MUST* drop the reference to `exp_array` explicitly so that + // owned_array is the only reference remaining in this function. + // + // Note that depending on the query there may still be other references + // to the underlying buffers, which would prevent reuse. 
The only way to + // know for sure is the result of `compute::unary_mut` + drop(exp_array); + + // If we have the only reference, compute the result directly into the same + // allocation as was used for the input array + match compute::unary_mut(owned_array, |exp| base.powf(exp)) { + Err(_orig_array) => { + // unary_mut will return the original array if there are other + // references into the underlying buffer (and thus reuse is + // impossible) + // + // In a real implementation, this case should fall back to + // calling `unary` and allocate a new array; in this example + // we will return an error for demonstration purposes + exec_err!("Could not reuse array for maybe_pow_in_place") + } + // a result of OK means the operation was run successfully + Ok(res) => Ok(Arc::new(res)), + } +} + /// In this example we register `PowUdf` as a user defined function /// and invoke it via the DataFrame API and SQL #[tokio::main] @@ -215,9 +269,29 @@ async fn main() -> Result<()> { // print the results df.show().await?; - // You can also invoke both pow(2, 10) and its alias my_pow(a, b) using SQL - let sql_df = ctx.sql("SELECT pow(2, 10), my_pow(a, b) FROM t").await?; - sql_df.show().await?; + // You can also invoke both pow(2, 10) and its alias my_pow(a, b) using SQL + ctx.sql("SELECT pow(2, 10), my_pow(a, b) FROM t") + .await? + .show() + .await?; + + // You can also invoke pow_in_place by passing a constant base and a + // column `a` as the exponent. If there is only a single + // reference to `a` the code works well + ctx.sql("SELECT pow(2, a) FROM t").await?.show().await?; + + // However, if there are multiple references to `a` in the evaluation + // the array storage can not be reused + let err = ctx + .sql("SELECT pow(2, a), pow(3, a) FROM t") + .await? 
+ .show() + .await + .unwrap_err(); + assert_eq!( + err.to_string(), + "Execution error: Could not reuse array for maybe_pow_in_place" + ); Ok(()) } diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index bf9c9f407eff..809c78f30eff 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -326,13 +326,15 @@ where } } +/// Arguments passed to [`ScalarUDFImpl::invoke_with_args`] when invoking a +/// scalar function. pub struct ScalarFunctionArgs<'a> { - // The evaluated arguments to the function - pub args: &'a [ColumnarValue], - // The number of rows in record batch being evaluated + /// The evaluated arguments to the function + pub args: Vec, + /// The number of rows in record batch being evaluated pub number_rows: usize, - // The return type of the scalar function returned (from `return_type` or `return_type_from_exprs`) - // when creating the physical expression from the logical expression + /// The return type of the scalar function returned (from `return_type` or `return_type_from_exprs`) + /// when creating the physical expression from the logical expression pub return_type: &'a DataType, } @@ -539,7 +541,7 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments /// to arrays, which will likely be simpler code, but be slower. 
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - self.invoke_batch(args.args, args.number_rows) + self.invoke_batch(&args.args, args.number_rows) } /// Invoke the function without `args`, instead the number of rows are provided, diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs index eaa91d1140ba..9f95b780ea4f 100644 --- a/datafusion/functions/src/datetime/to_local_time.rs +++ b/datafusion/functions/src/datetime/to_local_time.rs @@ -562,7 +562,7 @@ mod tests { fn test_to_local_time_helper(input: ScalarValue, expected: ScalarValue) { let res = ToLocalTimeFunc::new() .invoke_with_args(ScalarFunctionArgs { - args: &[ColumnarValue::Scalar(input)], + args: vec![ColumnarValue::Scalar(input)], number_rows: 1, return_type: &expected.data_type(), }) diff --git a/datafusion/functions/src/string/ascii.rs b/datafusion/functions/src/string/ascii.rs index 4f615b5b2c58..f366329b4f86 100644 --- a/datafusion/functions/src/string/ascii.rs +++ b/datafusion/functions/src/string/ascii.rs @@ -157,7 +157,7 @@ mod tests { ($INPUT:expr, $EXPECTED:expr) => { test_function!( AsciiFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], + vec![ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], $EXPECTED, i32, Int32, @@ -166,7 +166,7 @@ mod tests { test_function!( AsciiFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], + vec![ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], $EXPECTED, i32, Int32, @@ -175,7 +175,7 @@ mod tests { test_function!( AsciiFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], + vec![ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], $EXPECTED, i32, Int32, diff --git a/datafusion/functions/src/string/btrim.rs b/datafusion/functions/src/string/btrim.rs index ae79bb59f9c7..298d64f04ae9 100644 --- a/datafusion/functions/src/string/btrim.rs +++ b/datafusion/functions/src/string/btrim.rs @@ -152,9 +152,9 @@ mod tests { // String 
view cases for checking normal logic test_function!( BTrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( String::from("alphabet ") - ))),], + )))], Ok(Some("alphabet")), &str, Utf8View, @@ -162,7 +162,7 @@ mod tests { ); test_function!( BTrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( String::from(" alphabet ") ))),], Ok(Some("alphabet")), @@ -172,7 +172,7 @@ mod tests { ); test_function!( BTrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -185,7 +185,7 @@ mod tests { ); test_function!( BTrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -200,7 +200,7 @@ mod tests { ); test_function!( BTrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -214,7 +214,7 @@ mod tests { // Special string view case for checking unlined output(len > 12) test_function!( BTrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "xxxalphabetalphabetxxx" )))), @@ -228,7 +228,7 @@ mod tests { // String cases test_function!( BTrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( String::from("alphabet ") ))),], Ok(Some("alphabet")), @@ -238,7 +238,7 @@ mod tests { ); test_function!( BTrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( String::from("alphabet ") ))),], Ok(Some("alphabet")), @@ -248,7 +248,7 @@ mod tests { ); test_function!( BTrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("t")))), ], @@ -259,7 +259,7 @@ mod tests { ); test_function!( BTrimFunc::new(), - &[ + vec![ 
ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabe")))), ], @@ -270,7 +270,7 @@ mod tests { ); test_function!( BTrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::Scalar(ScalarValue::Utf8(None)), ], diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index 576c891ce467..895a7cdbf308 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -388,7 +388,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( ConcatFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("aa")), ColumnarValue::Scalar(ScalarValue::from("bb")), ColumnarValue::Scalar(ScalarValue::from("cc")), @@ -400,7 +400,7 @@ mod tests { ); test_function!( ConcatFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("aa")), ColumnarValue::Scalar(ScalarValue::Utf8(None)), ColumnarValue::Scalar(ScalarValue::from("cc")), @@ -412,7 +412,7 @@ mod tests { ); test_function!( ConcatFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(None))], + vec![ColumnarValue::Scalar(ScalarValue::Utf8(None))], Ok(Some("")), &str, Utf8, @@ -420,7 +420,7 @@ mod tests { ); test_function!( ConcatFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("aa")), ColumnarValue::Scalar(ScalarValue::Utf8View(None)), ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)), @@ -433,7 +433,7 @@ mod tests { ); test_function!( ConcatFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("aa")), ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)), ColumnarValue::Scalar(ScalarValue::from("cc")), @@ -445,7 +445,7 @@ mod tests { ); test_function!( ConcatFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some("aa".to_string()))), ColumnarValue::Scalar(ScalarValue::Utf8(Some("cc".to_string()))), ], diff --git 
a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index 610c4f0be697..7db8dbec4a71 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -404,7 +404,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( ConcatWsFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("|")), ColumnarValue::Scalar(ScalarValue::from("aa")), ColumnarValue::Scalar(ScalarValue::from("bb")), @@ -417,7 +417,7 @@ mod tests { ); test_function!( ConcatWsFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("|")), ColumnarValue::Scalar(ScalarValue::Utf8(None)), ], @@ -428,7 +428,7 @@ mod tests { ); test_function!( ConcatWsFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(None)), ColumnarValue::Scalar(ScalarValue::from("aa")), ColumnarValue::Scalar(ScalarValue::from("bb")), @@ -441,7 +441,7 @@ mod tests { ); test_function!( ConcatWsFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("|")), ColumnarValue::Scalar(ScalarValue::from("aa")), ColumnarValue::Scalar(ScalarValue::Utf8(None)), diff --git a/datafusion/functions/src/string/ends_with.rs b/datafusion/functions/src/string/ends_with.rs index fc7fc04f4363..1632fdd9943e 100644 --- a/datafusion/functions/src/string/ends_with.rs +++ b/datafusion/functions/src/string/ends_with.rs @@ -138,7 +138,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( EndsWithFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from("alph")), ], @@ -149,7 +149,7 @@ mod tests { ); test_function!( EndsWithFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from("bet")), ], @@ -160,7 +160,7 @@ mod tests { ); test_function!( EndsWithFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(None)), ColumnarValue::Scalar(ScalarValue::from("alph")), ], @@ -171,7 
+171,7 @@ mod tests { ); test_function!( EndsWithFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::Utf8(None)), ], diff --git a/datafusion/functions/src/string/initcap.rs b/datafusion/functions/src/string/initcap.rs index a9090b0a6f43..338a89091d29 100644 --- a/datafusion/functions/src/string/initcap.rs +++ b/datafusion/functions/src/string/initcap.rs @@ -163,7 +163,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::from("hi THOMAS"))], + vec![ColumnarValue::Scalar(ScalarValue::from("hi THOMAS"))], Ok(Some("Hi Thomas")), &str, Utf8, @@ -171,7 +171,7 @@ mod tests { ); test_function!( InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::from(""))], + vec![ColumnarValue::Scalar(ScalarValue::from(""))], Ok(Some("")), &str, Utf8, @@ -179,7 +179,7 @@ mod tests { ); test_function!( InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::from(""))], + vec![ColumnarValue::Scalar(ScalarValue::from(""))], Ok(Some("")), &str, Utf8, @@ -187,7 +187,7 @@ mod tests { ); test_function!( InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(None))], + vec![ColumnarValue::Scalar(ScalarValue::Utf8(None))], Ok(None), &str, Utf8, @@ -195,7 +195,7 @@ mod tests { ); test_function!( InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( "hi THOMAS".to_string() )))], Ok(Some("Hi Thomas")), @@ -205,7 +205,7 @@ mod tests { ); test_function!( InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( "hi THOMAS wIth M0re ThAN 12 ChaRs".to_string() )))], Ok(Some("Hi Thomas With M0re Than 12 Chars")), @@ -215,7 +215,7 @@ mod tests { ); test_function!( InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( 
"".to_string() )))], Ok(Some("")), @@ -225,7 +225,7 @@ mod tests { ); test_function!( InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(None))], + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(None))], Ok(None), &str, Utf8, diff --git a/datafusion/functions/src/string/ltrim.rs b/datafusion/functions/src/string/ltrim.rs index e0e83d1b01e3..b3e7f0bf007d 100644 --- a/datafusion/functions/src/string/ltrim.rs +++ b/datafusion/functions/src/string/ltrim.rs @@ -148,7 +148,7 @@ mod tests { // String view cases for checking normal logic test_function!( LtrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( String::from("alphabet ") ))),], Ok(Some("alphabet ")), @@ -158,7 +158,7 @@ mod tests { ); test_function!( LtrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( String::from(" alphabet ") ))),], Ok(Some("alphabet ")), @@ -168,7 +168,7 @@ mod tests { ); test_function!( LtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -181,7 +181,7 @@ mod tests { ); test_function!( LtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -196,7 +196,7 @@ mod tests { ); test_function!( LtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -210,7 +210,7 @@ mod tests { // Special string view case for checking unlined output(len > 12) test_function!( LtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "xxxalphabetalphabet" )))), @@ -224,7 +224,7 @@ mod tests { // String cases test_function!( LtrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( String::from("alphabet ") ))),], Ok(Some("alphabet ")), @@ -234,7 +234,7 @@ mod tests { ); 
test_function!( LtrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( String::from("alphabet ") ))),], Ok(Some("alphabet ")), @@ -244,7 +244,7 @@ mod tests { ); test_function!( LtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("t")))), ], @@ -255,7 +255,7 @@ mod tests { ); test_function!( LtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabe")))), ], @@ -266,7 +266,7 @@ mod tests { ); test_function!( LtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::Scalar(ScalarValue::Utf8(None)), ], diff --git a/datafusion/functions/src/string/octet_length.rs b/datafusion/functions/src/string/octet_length.rs index 2dbfa6746d61..26355556ff07 100644 --- a/datafusion/functions/src/string/octet_length.rs +++ b/datafusion/functions/src/string/octet_length.rs @@ -140,7 +140,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Int32(Some(12)))], + vec![ColumnarValue::Scalar(ScalarValue::Int32(Some(12)))], exec_err!( "The OCTET_LENGTH function can only accept strings, but got Int32." 
), @@ -150,7 +150,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Array(Arc::new(StringArray::from(vec![ + vec![ColumnarValue::Array(Arc::new(StringArray::from(vec![ String::from("chars"), String::from("chars2"), ])))], @@ -161,7 +161,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("chars")))), ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("chars")))) ], @@ -172,7 +172,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( String::from("chars") )))], Ok(Some(5)), @@ -182,7 +182,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( String::from("josé") )))], Ok(Some(5)), @@ -192,7 +192,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( String::from("") )))], Ok(Some(0)), @@ -202,7 +202,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(None))], + vec![ColumnarValue::Scalar(ScalarValue::Utf8(None))], Ok(None), i32, Int32, @@ -210,7 +210,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( String::from("joséjoséjoséjosé") )))], Ok(Some(20)), @@ -220,7 +220,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( String::from("josé") )))], Ok(Some(5)), @@ -230,7 +230,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( 
String::from("") )))], Ok(Some(0)), diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs index 4140a9b913ff..d16508c6af5a 100644 --- a/datafusion/functions/src/string/repeat.rs +++ b/datafusion/functions/src/string/repeat.rs @@ -171,7 +171,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( RepeatFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))), ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), ], @@ -182,7 +182,7 @@ mod tests { ); test_function!( RepeatFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(None)), ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), ], @@ -193,7 +193,7 @@ mod tests { ); test_function!( RepeatFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))), ColumnarValue::Scalar(ScalarValue::Int64(None)), ], @@ -205,7 +205,7 @@ mod tests { test_function!( RepeatFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("Pg")))), ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), ], @@ -216,7 +216,7 @@ mod tests { ); test_function!( RepeatFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(None)), ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), ], @@ -227,7 +227,7 @@ mod tests { ); test_function!( RepeatFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("Pg")))), ColumnarValue::Scalar(ScalarValue::Int64(None)), ], diff --git a/datafusion/functions/src/string/replace.rs b/datafusion/functions/src/string/replace.rs index 2439799f96d7..9b71d3871ea8 100644 --- a/datafusion/functions/src/string/replace.rs +++ b/datafusion/functions/src/string/replace.rs @@ -157,7 +157,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( ReplaceFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("aabbdqcbb")))), 
ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("bb")))), ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("ccc")))), @@ -170,7 +170,7 @@ mod tests { test_function!( ReplaceFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from( "aabbb" )))), @@ -185,7 +185,7 @@ mod tests { test_function!( ReplaceFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "aabbbcw" )))), diff --git a/datafusion/functions/src/string/rtrim.rs b/datafusion/functions/src/string/rtrim.rs index b4fe8d432432..ff8430f1530e 100644 --- a/datafusion/functions/src/string/rtrim.rs +++ b/datafusion/functions/src/string/rtrim.rs @@ -151,7 +151,7 @@ mod tests { // String view cases for checking normal logic test_function!( RtrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( String::from("alphabet ") ))),], Ok(Some("alphabet")), @@ -161,7 +161,7 @@ mod tests { ); test_function!( RtrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( String::from(" alphabet ") ))),], Ok(Some(" alphabet")), @@ -171,7 +171,7 @@ mod tests { ); test_function!( RtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -184,7 +184,7 @@ mod tests { ); test_function!( RtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -199,7 +199,7 @@ mod tests { ); test_function!( RtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -213,7 +213,7 @@ mod tests { // Special string view case for checking unlined output(len > 12) test_function!( RtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabetalphabetxxx" )))), @@ -227,7 +227,7 @@ mod tests { // String cases test_function!( 
RtrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( String::from("alphabet ") ))),], Ok(Some("alphabet")), @@ -237,7 +237,7 @@ mod tests { ); test_function!( RtrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( String::from(" alphabet ") ))),], Ok(Some(" alphabet")), @@ -247,7 +247,7 @@ mod tests { ); test_function!( RtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("t ")))), ], @@ -258,7 +258,7 @@ mod tests { ); test_function!( RtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabe")))), ], @@ -269,7 +269,7 @@ mod tests { ); test_function!( RtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::Scalar(ScalarValue::Utf8(None)), ], diff --git a/datafusion/functions/src/string/split_part.rs b/datafusion/functions/src/string/split_part.rs index e55325db756d..40bdd3ad01b2 100644 --- a/datafusion/functions/src/string/split_part.rs +++ b/datafusion/functions/src/string/split_part.rs @@ -270,7 +270,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( SplitPartFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from( "abc~@~def~@~ghi" )))), @@ -284,7 +284,7 @@ mod tests { ); test_function!( SplitPartFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from( "abc~@~def~@~ghi" )))), @@ -298,7 +298,7 @@ mod tests { ); test_function!( SplitPartFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from( "abc~@~def~@~ghi" )))), @@ -312,7 +312,7 @@ mod tests { ); test_function!( SplitPartFunc::new(), - &[ + vec![ 
ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from( "abc~@~def~@~ghi" )))), diff --git a/datafusion/functions/src/string/starts_with.rs b/datafusion/functions/src/string/starts_with.rs index 36dbd8167b4e..7354fda09584 100644 --- a/datafusion/functions/src/string/starts_with.rs +++ b/datafusion/functions/src/string/starts_with.rs @@ -159,7 +159,7 @@ mod tests { for (args, expected) in test_cases { test_function!( StartsWithFunc::new(), - &args, + args, Ok(expected), bool, Boolean, diff --git a/datafusion/functions/src/unicode/character_length.rs b/datafusion/functions/src/unicode/character_length.rs index 726822a8f887..822bdca9aca8 100644 --- a/datafusion/functions/src/unicode/character_length.rs +++ b/datafusion/functions/src/unicode/character_length.rs @@ -176,7 +176,7 @@ mod tests { ($INPUT:expr, $EXPECTED:expr) => { test_function!( CharacterLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], + vec![ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], $EXPECTED, i32, Int32, @@ -185,7 +185,7 @@ mod tests { test_function!( CharacterLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], + vec![ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], $EXPECTED, i64, Int64, @@ -194,7 +194,7 @@ mod tests { test_function!( CharacterLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], + vec![ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], $EXPECTED, i32, Int32, diff --git a/datafusion/functions/src/unicode/left.rs b/datafusion/functions/src/unicode/left.rs index ef2802340b14..e583523d84a0 100644 --- a/datafusion/functions/src/unicode/left.rs +++ b/datafusion/functions/src/unicode/left.rs @@ -188,7 +188,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( LeftFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::from(2i64)), ], @@ -199,7 +199,7 @@ mod tests { ); test_function!( LeftFunc::new(), - &[ + vec![ 
ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::from(200i64)), ], @@ -210,7 +210,7 @@ mod tests { ); test_function!( LeftFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::from(-2i64)), ], @@ -221,7 +221,7 @@ mod tests { ); test_function!( LeftFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::from(-200i64)), ], @@ -232,7 +232,7 @@ mod tests { ); test_function!( LeftFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::from(0i64)), ], @@ -243,7 +243,7 @@ mod tests { ); test_function!( LeftFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(None)), ColumnarValue::Scalar(ScalarValue::from(2i64)), ], @@ -254,7 +254,7 @@ mod tests { ); test_function!( LeftFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::Int64(None)), ], @@ -265,7 +265,7 @@ mod tests { ); test_function!( LeftFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("joséésoj")), ColumnarValue::Scalar(ScalarValue::from(5i64)), ], @@ -276,7 +276,7 @@ mod tests { ); test_function!( LeftFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("joséésoj")), ColumnarValue::Scalar(ScalarValue::from(-3i64)), ], diff --git a/datafusion/functions/src/unicode/lpad.rs b/datafusion/functions/src/unicode/lpad.rs index 6c8a4ec97bb0..f1750d2277ca 100644 --- a/datafusion/functions/src/unicode/lpad.rs +++ b/datafusion/functions/src/unicode/lpad.rs @@ -298,7 +298,7 @@ mod tests { ($INPUT:expr, $LENGTH:expr, $EXPECTED:expr) => { test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)), ColumnarValue::Scalar($LENGTH) ], @@ -310,7 +310,7 @@ mod tests { test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)), ColumnarValue::Scalar($LENGTH) 
], @@ -322,7 +322,7 @@ mod tests { test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)), ColumnarValue::Scalar($LENGTH) ], @@ -337,7 +337,7 @@ mod tests { // utf8, utf8 test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)), ColumnarValue::Scalar($LENGTH), ColumnarValue::Scalar(ScalarValue::Utf8($REPLACE)) @@ -350,7 +350,7 @@ mod tests { // utf8, largeutf8 test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)), ColumnarValue::Scalar($LENGTH), ColumnarValue::Scalar(ScalarValue::LargeUtf8($REPLACE)) @@ -363,7 +363,7 @@ mod tests { // utf8, utf8view test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)), ColumnarValue::Scalar($LENGTH), ColumnarValue::Scalar(ScalarValue::Utf8View($REPLACE)) @@ -377,7 +377,7 @@ mod tests { // largeutf8, utf8 test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)), ColumnarValue::Scalar($LENGTH), ColumnarValue::Scalar(ScalarValue::Utf8($REPLACE)) @@ -390,7 +390,7 @@ mod tests { // largeutf8, largeutf8 test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)), ColumnarValue::Scalar($LENGTH), ColumnarValue::Scalar(ScalarValue::LargeUtf8($REPLACE)) @@ -403,7 +403,7 @@ mod tests { // largeutf8, utf8view test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)), ColumnarValue::Scalar($LENGTH), ColumnarValue::Scalar(ScalarValue::Utf8View($REPLACE)) @@ -417,7 +417,7 @@ mod tests { // utf8view, utf8 test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)), ColumnarValue::Scalar($LENGTH), ColumnarValue::Scalar(ScalarValue::Utf8($REPLACE)) @@ -430,7 +430,7 @@ mod tests { // utf8view, largeutf8 test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)), 
ColumnarValue::Scalar($LENGTH), ColumnarValue::Scalar(ScalarValue::LargeUtf8($REPLACE)) @@ -443,7 +443,7 @@ mod tests { // utf8view, utf8view test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)), ColumnarValue::Scalar($LENGTH), ColumnarValue::Scalar(ScalarValue::Utf8View($REPLACE)) diff --git a/datafusion/functions/src/unicode/reverse.rs b/datafusion/functions/src/unicode/reverse.rs index 38c1f23cbd5a..8e3cf8845f98 100644 --- a/datafusion/functions/src/unicode/reverse.rs +++ b/datafusion/functions/src/unicode/reverse.rs @@ -151,7 +151,7 @@ mod tests { ($INPUT:expr, $EXPECTED:expr) => { test_function!( ReverseFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], + vec![ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], $EXPECTED, &str, Utf8, @@ -160,7 +160,7 @@ mod tests { test_function!( ReverseFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], + vec![ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], $EXPECTED, &str, LargeUtf8, @@ -169,7 +169,7 @@ mod tests { test_function!( ReverseFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], + vec![ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], $EXPECTED, &str, Utf8, diff --git a/datafusion/functions/src/unicode/right.rs b/datafusion/functions/src/unicode/right.rs index 1586e23eb8aa..4e414fbae5cb 100644 --- a/datafusion/functions/src/unicode/right.rs +++ b/datafusion/functions/src/unicode/right.rs @@ -192,7 +192,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( RightFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::from(2i64)), ], @@ -203,7 +203,7 @@ mod tests { ); test_function!( RightFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::from(200i64)), ], @@ -214,7 +214,7 @@ mod tests { ); test_function!( RightFunc::new(), - &[ + vec![ 
ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::from(-2i64)), ], @@ -225,7 +225,7 @@ mod tests { ); test_function!( RightFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::from(-200i64)), ], @@ -236,7 +236,7 @@ mod tests { ); test_function!( RightFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::from(0i64)), ], @@ -247,7 +247,7 @@ mod tests { ); test_function!( RightFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(None)), ColumnarValue::Scalar(ScalarValue::from(2i64)), ], @@ -258,7 +258,7 @@ mod tests { ); test_function!( RightFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::Int64(None)), ], @@ -269,7 +269,7 @@ mod tests { ); test_function!( RightFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("joséésoj")), ColumnarValue::Scalar(ScalarValue::from(5i64)), ], @@ -280,7 +280,7 @@ mod tests { ); test_function!( RightFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("joséésoj")), ColumnarValue::Scalar(ScalarValue::from(-3i64)), ], diff --git a/datafusion/functions/src/unicode/rpad.rs b/datafusion/functions/src/unicode/rpad.rs index 6e6bde3e177c..d5a0079c72aa 100644 --- a/datafusion/functions/src/unicode/rpad.rs +++ b/datafusion/functions/src/unicode/rpad.rs @@ -319,7 +319,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("josé")), ColumnarValue::Scalar(ScalarValue::from(5i64)), ], @@ -330,7 +330,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("hi")), ColumnarValue::Scalar(ScalarValue::from(5i64)), ], @@ -341,7 +341,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("hi")), 
ColumnarValue::Scalar(ScalarValue::from(0i64)), ], @@ -352,7 +352,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("hi")), ColumnarValue::Scalar(ScalarValue::Int64(None)), ], @@ -363,7 +363,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(None)), ColumnarValue::Scalar(ScalarValue::from(5i64)), ], @@ -374,7 +374,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("hi")), ColumnarValue::Scalar(ScalarValue::from(5i64)), ColumnarValue::Scalar(ScalarValue::from("xy")), @@ -386,7 +386,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("hi")), ColumnarValue::Scalar(ScalarValue::from(21i64)), ColumnarValue::Scalar(ScalarValue::from("abcdef")), @@ -398,7 +398,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("hi")), ColumnarValue::Scalar(ScalarValue::from(5i64)), ColumnarValue::Scalar(ScalarValue::from(" ")), @@ -410,7 +410,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("hi")), ColumnarValue::Scalar(ScalarValue::from(5i64)), ColumnarValue::Scalar(ScalarValue::from("")), @@ -422,7 +422,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(None)), ColumnarValue::Scalar(ScalarValue::from(5i64)), ColumnarValue::Scalar(ScalarValue::from("xy")), @@ -434,7 +434,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("hi")), ColumnarValue::Scalar(ScalarValue::Int64(None)), ColumnarValue::Scalar(ScalarValue::from("xy")), @@ -446,7 +446,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("hi")), ColumnarValue::Scalar(ScalarValue::from(5i64)), ColumnarValue::Scalar(ScalarValue::Utf8(None)), @@ -458,7 
+458,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("josé")), ColumnarValue::Scalar(ScalarValue::from(10i64)), ColumnarValue::Scalar(ScalarValue::from("xy")), @@ -470,7 +470,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("josé")), ColumnarValue::Scalar(ScalarValue::from(10i64)), ColumnarValue::Scalar(ScalarValue::from("éñ")), diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs index 5d1986e44c92..569af87a4b50 100644 --- a/datafusion/functions/src/unicode/strpos.rs +++ b/datafusion/functions/src/unicode/strpos.rs @@ -218,7 +218,7 @@ mod tests { ($lhs:literal, $rhs:literal -> $result:literal; $t1:ident $t2:ident $t3:ident $t4:ident $t5:ident) => { test_function!( StrposFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::$t1(Some($lhs.to_owned()))), ColumnarValue::Scalar(ScalarValue::$t2(Some($rhs.to_owned()))), ], diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs index 0ac050c707bf..141984cf2674 100644 --- a/datafusion/functions/src/unicode/substr.rs +++ b/datafusion/functions/src/unicode/substr.rs @@ -523,7 +523,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(None)), ColumnarValue::Scalar(ScalarValue::from(1i64)), ], @@ -534,7 +534,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -547,7 +547,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "this és longer than 12B" )))), @@ -561,7 +561,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "this is longer than 12B" )))), @@ -574,7 +574,7 @@ 
mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "joséésoj" )))), @@ -587,7 +587,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -601,7 +601,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -615,7 +615,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(0i64)), ], @@ -626,7 +626,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("joséésoj")), ColumnarValue::Scalar(ScalarValue::from(5i64)), ], @@ -637,7 +637,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("joséésoj")), ColumnarValue::Scalar(ScalarValue::from(-5i64)), ], @@ -648,7 +648,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(1i64)), ], @@ -659,7 +659,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(2i64)), ], @@ -670,7 +670,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(3i64)), ], @@ -681,7 +681,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(-3i64)), ], @@ -692,7 +692,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(30i64)), ], @@ -703,7 +703,7 @@ mod 
tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::Int64(None)), ], @@ -714,7 +714,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(3i64)), ColumnarValue::Scalar(ScalarValue::from(2i64)), @@ -726,7 +726,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(3i64)), ColumnarValue::Scalar(ScalarValue::from(20i64)), @@ -738,7 +738,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(0i64)), ColumnarValue::Scalar(ScalarValue::from(5i64)), @@ -751,7 +751,7 @@ mod tests { // starting from 5 (10 + -5) test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(-5i64)), ColumnarValue::Scalar(ScalarValue::from(10i64)), @@ -764,7 +764,7 @@ mod tests { // starting from -1 (4 + -5) test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(-5i64)), ColumnarValue::Scalar(ScalarValue::from(4i64)), @@ -777,7 +777,7 @@ mod tests { // starting from 0 (5 + -5) test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(-5i64)), ColumnarValue::Scalar(ScalarValue::from(5i64)), @@ -789,7 +789,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::Int64(None)), ColumnarValue::Scalar(ScalarValue::from(20i64)), @@ -801,7 +801,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ 
ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(3i64)), ColumnarValue::Scalar(ScalarValue::Int64(None)), @@ -813,7 +813,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(1i64)), ColumnarValue::Scalar(ScalarValue::from(-1i64)), @@ -825,7 +825,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("joséésoj")), ColumnarValue::Scalar(ScalarValue::from(5i64)), ColumnarValue::Scalar(ScalarValue::from(2i64)), @@ -851,7 +851,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abc")), ColumnarValue::Scalar(ScalarValue::from(-9223372036854775808i64)), ], @@ -862,7 +862,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("overflow")), ColumnarValue::Scalar(ScalarValue::from(-9223372036854775808i64)), ColumnarValue::Scalar(ScalarValue::from(1i64)), diff --git a/datafusion/functions/src/unicode/substrindex.rs b/datafusion/functions/src/unicode/substrindex.rs index 825666b0455e..61cd989bb964 100644 --- a/datafusion/functions/src/unicode/substrindex.rs +++ b/datafusion/functions/src/unicode/substrindex.rs @@ -253,7 +253,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( SubstrIndexFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), ColumnarValue::Scalar(ScalarValue::from(".")), ColumnarValue::Scalar(ScalarValue::from(1i64)), @@ -265,7 +265,7 @@ mod tests { ); test_function!( SubstrIndexFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), ColumnarValue::Scalar(ScalarValue::from(".")), ColumnarValue::Scalar(ScalarValue::from(2i64)), @@ -277,7 +277,7 @@ mod tests { ); test_function!( SubstrIndexFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), 
ColumnarValue::Scalar(ScalarValue::from(".")), ColumnarValue::Scalar(ScalarValue::from(-2i64)), @@ -289,7 +289,7 @@ mod tests { ); test_function!( SubstrIndexFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), ColumnarValue::Scalar(ScalarValue::from(".")), ColumnarValue::Scalar(ScalarValue::from(-1i64)), @@ -301,7 +301,7 @@ mod tests { ); test_function!( SubstrIndexFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), ColumnarValue::Scalar(ScalarValue::from(".")), ColumnarValue::Scalar(ScalarValue::from(0i64)), @@ -313,7 +313,7 @@ mod tests { ); test_function!( SubstrIndexFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("")), ColumnarValue::Scalar(ScalarValue::from(".")), ColumnarValue::Scalar(ScalarValue::from(1i64)), @@ -325,7 +325,7 @@ mod tests { ); test_function!( SubstrIndexFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), ColumnarValue::Scalar(ScalarValue::from("")), ColumnarValue::Scalar(ScalarValue::from(1i64)), diff --git a/datafusion/functions/src/unicode/translate.rs b/datafusion/functions/src/unicode/translate.rs index 780603777133..9257b0b04e61 100644 --- a/datafusion/functions/src/unicode/translate.rs +++ b/datafusion/functions/src/unicode/translate.rs @@ -201,7 +201,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( TranslateFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("12345")), ColumnarValue::Scalar(ScalarValue::from("143")), ColumnarValue::Scalar(ScalarValue::from("ax")) @@ -213,7 +213,7 @@ mod tests { ); test_function!( TranslateFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(None)), ColumnarValue::Scalar(ScalarValue::from("143")), ColumnarValue::Scalar(ScalarValue::from("ax")) @@ -225,7 +225,7 @@ mod tests { ); test_function!( TranslateFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("12345")), ColumnarValue::Scalar(ScalarValue::Utf8(None)), 
ColumnarValue::Scalar(ScalarValue::from("ax")) @@ -237,7 +237,7 @@ mod tests { ); test_function!( TranslateFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("12345")), ColumnarValue::Scalar(ScalarValue::from("143")), ColumnarValue::Scalar(ScalarValue::Utf8(None)) @@ -249,7 +249,7 @@ mod tests { ); test_function!( TranslateFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("é2íñ5")), ColumnarValue::Scalar(ScalarValue::from("éñí")), ColumnarValue::Scalar(ScalarValue::from("óü")), @@ -262,7 +262,7 @@ mod tests { #[cfg(not(feature = "unicode_expressions"))] test_function!( TranslateFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("12345")), ColumnarValue::Scalar(ScalarValue::from("143")), ColumnarValue::Scalar(ScalarValue::from("ax")), diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 45f77325eea3..e312d5de59fb 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -134,20 +134,20 @@ impl PhysicalExpr for ScalarFunctionExpr { } fn evaluate(&self, batch: &RecordBatch) -> Result { - let inputs = self + let args = self .args .iter() .map(|e| e.evaluate(batch)) .collect::>>()?; - let input_empty = inputs.is_empty(); - let input_all_scalar = inputs + let input_empty = args.is_empty(); + let input_all_scalar = args .iter() .all(|arg| matches!(arg, ColumnarValue::Scalar(_))); // evaluate the function let output = self.fun.invoke_with_args(ScalarFunctionArgs { - args: inputs.as_slice(), + args, number_rows: batch.num_rows(), return_type: &self.return_type, })?; From 6dd3f3a87da75547025f07b8e163191316215405 Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Sun, 8 Dec 2024 21:58:18 +0800 Subject: [PATCH 17/35] Temporary fix for CI (#13689) --- datafusion/core/Cargo.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 
4706afc897c2..4583b84cdae6 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -132,6 +132,10 @@ xz2 = { version = "0.1", optional = true, features = ["static"] } zstd = { version = "0.13", optional = true, default-features = false } [dev-dependencies] +# Temporary fix for https://github.com/apache/datafusion/issues/13686 +# TODO: Remove it once the upstream has a fix +lexical-write-integer = { version = "=1.0.2" } + arrow-buffer = { workspace = true } async-trait = { workspace = true } criterion = { version = "0.5", features = ["async_tokio"] } From de36fb6f67292196f45813ef5b87db4510aab5c9 Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Sun, 8 Dec 2024 23:37:53 +0800 Subject: [PATCH 18/35] refactor: use `LazyLock` in the `user_doc` macro (#13684) * refactor: use `LazyLock` in the `user_doc` macro * Fix cargo doc * Update datafusion/macros/src/lib.rs * Fix doc comment --------- Co-authored-by: Oleks V --- .../src/approx_distinct.rs | 2 - .../functions-aggregate/src/approx_median.rs | 2 - .../src/approx_percentile_cont.rs | 3 +- .../src/approx_percentile_cont_with_weight.rs | 3 +- .../functions-aggregate/src/array_agg.rs | 3 +- datafusion/functions-aggregate/src/average.rs | 3 +- .../functions-aggregate/src/bool_and_or.rs | 2 - .../functions-aggregate/src/correlation.rs | 3 +- datafusion/functions-aggregate/src/count.rs | 3 +- .../functions-aggregate/src/covariance.rs | 2 - .../functions-aggregate/src/first_last.rs | 3 +- .../functions-aggregate/src/grouping.rs | 2 - datafusion/functions-aggregate/src/median.rs | 3 +- datafusion/functions-aggregate/src/min_max.rs | 2 - .../functions-aggregate/src/nth_value.rs | 3 +- datafusion/functions-aggregate/src/stddev.rs | 3 +- .../functions-aggregate/src/string_agg.rs | 2 - datafusion/functions-aggregate/src/sum.rs | 2 - .../functions-aggregate/src/variance.rs | 2 - datafusion/functions/src/datetime/to_date.rs | 2 - datafusion/functions/src/math/abs.rs | 3 +- datafusion/functions/src/string/ltrim.rs | 2 
- datafusion/macros/src/user_doc.rs | 105 ++++++++++-------- 23 files changed, 68 insertions(+), 92 deletions(-) diff --git a/datafusion/functions-aggregate/src/approx_distinct.rs b/datafusion/functions-aggregate/src/approx_distinct.rs index 74691ba740fd..1d378fff176f 100644 --- a/datafusion/functions-aggregate/src/approx_distinct.rs +++ b/datafusion/functions-aggregate/src/approx_distinct.rs @@ -31,7 +31,6 @@ use datafusion_common::ScalarValue; use datafusion_common::{ downcast_value, internal_err, not_impl_err, DataFusionError, Result, }; -use datafusion_doc::DocSection; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ @@ -42,7 +41,6 @@ use std::any::Any; use std::fmt::{Debug, Formatter}; use std::hash::Hash; use std::marker::PhantomData; -use std::sync::OnceLock; make_udaf_expr_and_func!( ApproxDistinct, diff --git a/datafusion/functions-aggregate/src/approx_median.rs b/datafusion/functions-aggregate/src/approx_median.rs index d4441da61292..5d174a752296 100644 --- a/datafusion/functions-aggregate/src/approx_median.rs +++ b/datafusion/functions-aggregate/src/approx_median.rs @@ -19,13 +19,11 @@ use std::any::Any; use std::fmt::Debug; -use std::sync::OnceLock; use arrow::{datatypes::DataType, datatypes::Field}; use arrow_schema::DataType::{Float64, UInt64}; use datafusion_common::{not_impl_err, plan_err, Result}; -use datafusion_doc::DocSection; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::utils::format_state_name; diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index 13407fecf220..61424e8f2445 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -18,7 +18,7 @@ use std::any::Any; use std::fmt::{Debug, 
Formatter}; use std::mem::size_of_val; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::array::{Array, RecordBatch}; use arrow::compute::{filter, is_not_null}; @@ -35,7 +35,6 @@ use datafusion_common::{ downcast_value, internal_err, not_impl_datafusion_err, not_impl_err, plan_err, DataFusionError, Result, ScalarValue, }; -use datafusion_doc::DocSection; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS}; use datafusion_expr::utils::format_state_name; diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs index 485874aeb284..10b9b06f1f94 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs @@ -18,7 +18,7 @@ use std::any::Any; use std::fmt::{Debug, Formatter}; use std::mem::size_of_val; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::{ array::ArrayRef, @@ -27,7 +27,6 @@ use arrow::{ use datafusion_common::ScalarValue; use datafusion_common::{not_impl_err, plan_err, Result}; -use datafusion_doc::DocSection; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::Volatility::Immutable; diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index 98530a9fc236..b75de83f6ace 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -25,7 +25,6 @@ use datafusion_common::cast::as_list_array; use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder}; use datafusion_common::{exec_err, ScalarValue}; use datafusion_common::{internal_err, Result}; -use datafusion_doc::DocSection; use datafusion_expr::function::{AccumulatorArgs, 
StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{Accumulator, Signature, Volatility}; @@ -36,7 +35,7 @@ use datafusion_macros::user_doc; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use std::collections::{HashSet, VecDeque}; use std::mem::{size_of, size_of_val}; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; make_udaf_expr_and_func!( ArrayAgg, diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 65ca441517a0..18874f831e9d 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -42,14 +42,13 @@ use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls: filtered_null_mask, set_nulls, }; -use datafusion_doc::DocSection; use datafusion_functions_aggregate_common::utils::DecimalAverager; use datafusion_macros::user_doc; use log::debug; use std::any::Any; use std::fmt::Debug; use std::mem::{size_of, size_of_val}; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; make_udaf_expr_and_func!( Avg, diff --git a/datafusion/functions-aggregate/src/bool_and_or.rs b/datafusion/functions-aggregate/src/bool_and_or.rs index 1b5b20f43b3e..29dfc68e0576 100644 --- a/datafusion/functions-aggregate/src/bool_and_or.rs +++ b/datafusion/functions-aggregate/src/bool_and_or.rs @@ -19,7 +19,6 @@ use std::any::Any; use std::mem::size_of_val; -use std::sync::OnceLock; use arrow::array::ArrayRef; use arrow::array::BooleanArray; @@ -38,7 +37,6 @@ use datafusion_expr::{ Signature, Volatility, }; -use datafusion_doc::DocSection; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::bool_op::BooleanGroupsAccumulator; use datafusion_macros::user_doc; diff --git a/datafusion/functions-aggregate/src/correlation.rs b/datafusion/functions-aggregate/src/correlation.rs index b40555bf6c7f..a0ccdb0ae7d0 100644 --- a/datafusion/functions-aggregate/src/correlation.rs +++ 
b/datafusion/functions-aggregate/src/correlation.rs @@ -20,7 +20,7 @@ use std::any::Any; use std::fmt::Debug; use std::mem::size_of_val; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::compute::{and, filter, is_not_null}; use arrow::{ @@ -31,7 +31,6 @@ use arrow::{ use crate::covariance::CovarianceAccumulator; use crate::stddev::StddevAccumulator; use datafusion_common::{plan_err, Result, ScalarValue}; -use datafusion_doc::DocSection; use datafusion_expr::{ function::{AccumulatorArgs, StateFieldsArgs}, type_coercion::aggregates::NUMERICS, diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 550df8cb4f7d..b4164c211c35 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -17,7 +17,6 @@ use ahash::RandomState; use datafusion_common::stats::Precision; -use datafusion_doc::DocSection; use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator; use datafusion_macros::user_doc; use datafusion_physical_expr::expressions; @@ -25,7 +24,7 @@ use std::collections::HashSet; use std::fmt::Debug; use std::mem::{size_of, size_of_val}; use std::ops::BitAnd; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::{ array::{ArrayRef, AsArray}, diff --git a/datafusion/functions-aggregate/src/covariance.rs b/datafusion/functions-aggregate/src/covariance.rs index adb546e4d906..ffbf2ceef052 100644 --- a/datafusion/functions-aggregate/src/covariance.rs +++ b/datafusion/functions-aggregate/src/covariance.rs @@ -19,7 +19,6 @@ use std::fmt::Debug; use std::mem::size_of_val; -use std::sync::OnceLock; use arrow::{ array::{ArrayRef, Float64Array, UInt64Array}, @@ -31,7 +30,6 @@ use datafusion_common::{ downcast_value, plan_err, unwrap_or_internal_err, DataFusionError, Result, ScalarValue, }; -use datafusion_doc::DocSection; use datafusion_expr::{ function::{AccumulatorArgs, StateFieldsArgs}, 
type_coercion::aggregates::NUMERICS, diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index f3e66edbc009..9ad55d91a68b 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -20,7 +20,7 @@ use std::any::Any; use std::fmt::Debug; use std::mem::size_of_val; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::array::{ArrayRef, AsArray, BooleanArray}; use arrow::compute::{self, lexsort_to_indices, take_arrays, SortColumn}; @@ -29,7 +29,6 @@ use datafusion_common::utils::{compare_rows, get_row_at_idx}; use datafusion_common::{ arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, }; -use datafusion_doc::DocSection; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; use datafusion_expr::{ diff --git a/datafusion/functions-aggregate/src/grouping.rs b/datafusion/functions-aggregate/src/grouping.rs index 36bdf68c1b0e..445774ff11e7 100644 --- a/datafusion/functions-aggregate/src/grouping.rs +++ b/datafusion/functions-aggregate/src/grouping.rs @@ -19,12 +19,10 @@ use std::any::Any; use std::fmt; -use std::sync::OnceLock; use arrow::datatypes::DataType; use arrow::datatypes::Field; use datafusion_common::{not_impl_err, Result}; -use datafusion_doc::DocSection; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::utils::format_state_name; diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index db5fbf00165f..70f192c32ae1 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -18,7 +18,7 @@ use std::cmp::Ordering; use std::fmt::{Debug, Formatter}; use std::mem::{size_of, size_of_val}; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use 
arrow::array::{downcast_integer, ArrowNumericType}; use arrow::{ @@ -34,7 +34,6 @@ use arrow::array::ArrowNativeTypeOp; use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType}; use datafusion_common::{DataFusionError, HashSet, Result, ScalarValue}; -use datafusion_doc::DocSection; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::{ function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl, diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index acbeebaad68b..a0f7634c5fa8 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -55,7 +55,6 @@ use arrow::datatypes::{ use crate::min_max::min_max_bytes::MinMaxBytesAccumulator; use datafusion_common::ScalarValue; -use datafusion_doc::DocSection; use datafusion_expr::{ function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, @@ -65,7 +64,6 @@ use datafusion_macros::user_doc; use half::f16; use std::mem::size_of_val; use std::ops::Deref; -use std::sync::OnceLock; fn get_min_max_result_type(input_types: &[DataType]) -> Result> { // make sure that the input types only has one element. 
diff --git a/datafusion/functions-aggregate/src/nth_value.rs b/datafusion/functions-aggregate/src/nth_value.rs index 15b9e97516ca..8252fd6baaa3 100644 --- a/datafusion/functions-aggregate/src/nth_value.rs +++ b/datafusion/functions-aggregate/src/nth_value.rs @@ -21,14 +21,13 @@ use std::any::Any; use std::collections::VecDeque; use std::mem::{size_of, size_of_val}; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::array::{new_empty_array, ArrayRef, AsArray, StructArray}; use arrow_schema::{DataType, Field, Fields}; use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder}; use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValue}; -use datafusion_doc::DocSection; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index 09a39e342cce..adf86a128cfb 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -20,14 +20,13 @@ use std::any::Any; use std::fmt::{Debug, Formatter}; use std::mem::align_of_val; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::array::Float64Array; use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; use datafusion_common::{internal_err, not_impl_err, Result}; use datafusion_common::{plan_err, ScalarValue}; -use datafusion_doc::DocSection; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index 5a52bec55f15..7643b44e11d5 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -22,7 +22,6 @@ use arrow_schema::DataType; use 
datafusion_common::cast::as_generic_string_array; use datafusion_common::Result; use datafusion_common::{not_impl_err, ScalarValue}; -use datafusion_doc::DocSection; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::{ Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, Volatility, @@ -31,7 +30,6 @@ use datafusion_macros::user_doc; use datafusion_physical_expr::expressions::Literal; use std::any::Any; use std::mem::size_of_val; -use std::sync::OnceLock; make_udaf_expr_and_func!( StringAgg, diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs index ccc6ee3cf925..6c2854f6bc24 100644 --- a/datafusion/functions-aggregate/src/sum.rs +++ b/datafusion/functions-aggregate/src/sum.rs @@ -22,7 +22,6 @@ use datafusion_expr::utils::AggregateOrderSensitivity; use std::any::Any; use std::collections::HashSet; use std::mem::{size_of, size_of_val}; -use std::sync::OnceLock; use arrow::array::Array; use arrow::array::ArrowNativeTypeOp; @@ -35,7 +34,6 @@ use arrow::datatypes::{ }; use arrow::{array::ArrayRef, datatypes::Field}; use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue}; -use datafusion_doc::DocSection; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::utils::format_state_name; diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 70b10734088f..8aa7a40ce320 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -25,13 +25,11 @@ use arrow::{ datatypes::{DataType, Field}, }; use std::mem::{size_of, size_of_val}; -use std::sync::OnceLock; use std::{fmt::Debug, sync::Arc}; use datafusion_common::{ downcast_value, not_impl_err, plan_err, DataFusionError, Result, ScalarValue, }; -use datafusion_doc::DocSection; use datafusion_expr::{ function::{AccumulatorArgs, StateFieldsArgs}, 
utils::format_state_name, diff --git a/datafusion/functions/src/datetime/to_date.rs b/datafusion/functions/src/datetime/to_date.rs index e2edea843e98..091d0ba37644 100644 --- a/datafusion/functions/src/datetime/to_date.rs +++ b/datafusion/functions/src/datetime/to_date.rs @@ -22,13 +22,11 @@ use arrow::error::ArrowError::ParseError; use arrow::{array::types::Date32Type, compute::kernels::cast_utils::Parser}; use datafusion_common::error::DataFusionError; use datafusion_common::{arrow_err, exec_err, internal_datafusion_err, Result}; -use datafusion_doc::DocSection; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; use std::any::Any; -use std::sync::OnceLock; #[user_doc( doc_section(label = "Time and Date Functions"), diff --git a/datafusion/functions/src/math/abs.rs b/datafusion/functions/src/math/abs.rs index c0c7c6f0f6b6..e3d448083e26 100644 --- a/datafusion/functions/src/math/abs.rs +++ b/datafusion/functions/src/math/abs.rs @@ -18,7 +18,7 @@ //! 
math expressions use std::any::Any; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::array::{ ArrayRef, Decimal128Array, Decimal256Array, Float32Array, Float64Array, Int16Array, @@ -27,7 +27,6 @@ use arrow::array::{ use arrow::datatypes::DataType; use arrow::error::ArrowError; use datafusion_common::{exec_err, not_impl_err, DataFusionError, Result}; -use datafusion_doc::DocSection; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ diff --git a/datafusion/functions/src/string/ltrim.rs b/datafusion/functions/src/string/ltrim.rs index b3e7f0bf007d..0bc62ee5000d 100644 --- a/datafusion/functions/src/string/ltrim.rs +++ b/datafusion/functions/src/string/ltrim.rs @@ -22,12 +22,10 @@ use std::any::Any; use crate::string::common::*; use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::{exec_err, Result}; -use datafusion_doc::DocSection; use datafusion_expr::function::Hint; use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; -use std::sync::OnceLock; /// Returns the longest string with leading characters removed. If the characters are not specified, whitespace is removed. /// ltrim('zzzytest', 'xyz') = 'test' diff --git a/datafusion/macros/src/user_doc.rs b/datafusion/macros/src/user_doc.rs index 54b688ac2a49..441b3db2a133 100644 --- a/datafusion/macros/src/user_doc.rs +++ b/datafusion/macros/src/user_doc.rs @@ -26,16 +26,19 @@ use syn::{parse_macro_input, DeriveInput, LitStr}; /// declared on `AggregateUDF`, `WindowUDFImpl`, `ScalarUDFImpl` traits. /// /// Example: +/// ```ignore /// #[user_doc( /// doc_section(include = "true", label = "Time and Date Functions"), -/// description = r"Converts a value to a date (`YYYY-MM-DD`)." 
-/// sql_example = "```sql\n\ -/// \> select to_date('2023-01-31');\n\ -/// +-----------------------------+\n\ -/// | to_date(Utf8(\"2023-01-31\")) |\n\ -/// +-----------------------------+\n\ -/// | 2023-01-31 |\n\ -/// +-----------------------------+\n\"), +/// description = r"Converts a value to a date (`YYYY-MM-DD`).", +/// syntax_example = "to_date('2017-05-31', '%Y-%m-%d')", +/// sql_example = r#"```sql +/// > select to_date('2023-01-31'); +/// +-----------------------------+ +/// | to_date(Utf8(\"2023-01-31\")) | +/// +-----------------------------+ +/// | 2023-01-31 | +/// +-----------------------------+ +/// ```"#, /// standard_argument(name = "expression", prefix = "String"), /// argument( /// name = "format_n", @@ -48,40 +51,50 @@ use syn::{parse_macro_input, DeriveInput, LitStr}; /// pub struct ToDateFunc { /// signature: Signature, /// } -/// +/// ``` /// will generate the following code /// -/// #[derive(Debug)] pub struct ToDateFunc { signature : Signature, } -/// use datafusion_doc :: DocSection; -/// use datafusion_doc :: DocumentationBuilder; -/// static DOCUMENTATION : OnceLock < Documentation > = OnceLock :: new(); -/// impl ToDateFunc -/// { -/// fn doc(& self) -> Option < & Documentation > -/// { -/// Some(DOCUMENTATION.get_or_init(|| -/// { -/// Documentation :: -/// builder(DocSection -/// { -/// include : true, label : "Time and Date Functions", description -/// : None -/// }, r"Converts a value to a date (`YYYY-MM-DD`).") -/// .with_syntax_example("to_date('2017-05-31', '%Y-%m-%d')".to_string(),"```sql\n\ -/// \> select to_date('2023-01-31');\n\ -/// +-----------------------------+\n\ -/// | to_date(Utf8(\"2023-01-31\")) |\n\ -/// +-----------------------------+\n\ -/// | 2023-01-31 |\n\ -/// +-----------------------------+\n\) -/// .with_standard_argument("expression", "String".into()) -/// .with_argument("format_n", -/// r"Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse 
the expression. Formats will be tried in the order -/// they appear with the first successful one being returned. If none of the formats successfully parse the expression -/// an error will be returned.").build() -/// })) +/// ```ignore +/// pub struct ToDateFunc { +/// signature: Signature, +/// } +/// impl ToDateFunc { +/// fn doc(&self) -> Option<&datafusion_doc::Documentation> { +/// static DOCUMENTATION: std::sync::LazyLock< +/// datafusion_doc::Documentation, +/// > = std::sync::LazyLock::new(|| { +/// datafusion_doc::Documentation::builder( +/// datafusion_doc::DocSection { +/// include: true, +/// label: "Time and Date Functions", +/// description: None, +/// }, +/// r"Converts a value to a date (`YYYY-MM-DD`).".to_string(), +/// "to_date('2017-05-31', '%Y-%m-%d')".to_string(), +/// ) +/// .with_sql_example( +/// r#"```sql +/// > select to_date('2023-01-31'); +/// +-----------------------------+ +/// | to_date(Utf8(\"2023-01-31\")) | +/// +-----------------------------+ +/// | 2023-01-31 | +/// +-----------------------------+ +/// ```"#, +/// ) +/// .with_standard_argument("expression", "String".into()) +/// .with_argument( +/// "format_n", +/// r"Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order +/// they appear with the first successful one being returned. If none of the formats successfully parse the expression +/// an error will be returned.", +/// ) +/// .build() +/// }); +/// Some(&DOCUMENTATION) /// } /// } +/// ``` #[proc_macro_attribute] pub fn user_doc(args: TokenStream, input: TokenStream) -> TokenStream { let mut doc_section_include: Option = None; @@ -235,19 +248,14 @@ pub fn user_doc(args: TokenStream, input: TokenStream) -> TokenStream { } }); - let lock_name: proc_macro2::TokenStream = - format!("{name}_DOCUMENTATION").parse().unwrap(); - let generated = quote! 
{ #input - static #lock_name: OnceLock = OnceLock::new(); - impl #name { - - fn doc(&self) -> Option<&Documentation> { - Some(#lock_name.get_or_init(|| { - Documentation::builder(DocSection { include: #doc_section_include, label: #doc_section_lbl, description: #doc_section_description }, + fn doc(&self) -> Option<&datafusion_doc::Documentation> { + static DOCUMENTATION: std::sync::LazyLock = + std::sync::LazyLock::new(|| { + datafusion_doc::Documentation::builder(datafusion_doc::DocSection { include: #doc_section_include, label: #doc_section_lbl, description: #doc_section_description }, #description.to_string(), #syntax_example.to_string()) #sql_example #alt_syntax_example @@ -255,7 +263,8 @@ pub fn user_doc(args: TokenStream, input: TokenStream) -> TokenStream { #(#udf_args)* #(#related_udfs)* .build() - })) + }); + Some(&DOCUMENTATION) } } }; From 7728525c75194c95c18ed9cbcba9eb8be609bb1f Mon Sep 17 00:00:00 2001 From: Alexander Huszagh Date: Sun, 8 Dec 2024 19:51:44 -0600 Subject: [PATCH 19/35] Unlock lexical-write-integer version. (#13693) Issue was patched as of lexical release 1.0.5. 
Reverts #13689 Closes #13686 --- datafusion/core/Cargo.toml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 4583b84cdae6..4706afc897c2 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -132,10 +132,6 @@ xz2 = { version = "0.1", optional = true, features = ["static"] } zstd = { version = "0.13", optional = true, default-features = false } [dev-dependencies] -# Temporary fix for https://github.com/apache/datafusion/issues/13686 -# TODO: Remove it once the upstream has a fix -lexical-write-integer = { version = "=1.0.2" } - arrow-buffer = { workspace = true } async-trait = { workspace = true } criterion = { version = "0.5", features = ["async_tokio"] } From 4884ac28ca131e4f1c54eede343870b3db086642 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Mon, 9 Dec 2024 04:05:37 -0800 Subject: [PATCH 20/35] Minor: Use `div_ceil` --- .../physical-plan/src/joins/cross_join.rs | 27 ++++++++++--------- .../src/joins/nested_loop_join.rs | 24 ++++++++--------- 2 files changed, 26 insertions(+), 25 deletions(-) diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index f53fe13df15e..8bf675e87362 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -190,18 +190,21 @@ async fn load_left_input( // Load all batches and count the rows let (batches, _metrics, reservation) = stream - .try_fold((Vec::new(), metrics, reservation), |mut acc, batch| async { - let batch_size = batch.get_array_memory_size(); - // Reserve memory for incoming batch - acc.2.try_grow(batch_size)?; - // Update metrics - acc.1.build_mem_used.add(batch_size); - acc.1.build_input_batches.add(1); - acc.1.build_input_rows.add(batch.num_rows()); - // Push batch to output - acc.0.push(batch); - Ok(acc) - }) + .try_fold( + (Vec::new(), metrics, reservation), + |(mut batches, metrics, mut reservation), batch| 
async { + let batch_size = batch.get_array_memory_size(); + // Reserve memory for incoming batch + reservation.try_grow(batch_size)?; + // Update metrics + metrics.build_mem_used.add(batch_size); + metrics.build_input_batches.add(1); + metrics.build_input_rows.add(batch.num_rows()); + // Push batch to output + batches.push(batch); + Ok((batches, metrics, reservation)) + }, + ) .await?; let merged_batch = concat_batches(&left_schema, &batches)?; diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 2beeb92da499..d174564178df 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -45,7 +45,6 @@ use arrow::array::{BooleanBufferBuilder, UInt32Array, UInt64Array}; use arrow::compute::concat_batches; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use arrow::util::bit_util; use datafusion_common::{ exec_datafusion_err, internal_err, JoinSide, Result, Statistics, }; @@ -440,17 +439,17 @@ async fn collect_left_input( let (batches, metrics, mut reservation) = stream .try_fold( (Vec::new(), join_metrics, reservation), - |mut acc, batch| async { + |(mut batches, metrics, mut reservation), batch| async { let batch_size = batch.get_array_memory_size(); // Reserve memory for incoming batch - acc.2.try_grow(batch_size)?; + reservation.try_grow(batch_size)?; // Update metrics - acc.1.build_mem_used.add(batch_size); - acc.1.build_input_batches.add(1); - acc.1.build_input_rows.add(batch.num_rows()); + metrics.build_mem_used.add(batch_size); + metrics.build_input_batches.add(1); + metrics.build_input_rows.add(batch.num_rows()); // Push batch to output - acc.0.push(batch); - Ok(acc) + batches.push(batch); + Ok((batches, metrics, reservation)) }, ) .await?; @@ -459,14 +458,13 @@ async fn collect_left_input( // Reserve memory for visited_left_side bitmap if required by join type let visited_left_side = 
if with_visited_left_side { - // TODO: Replace `ceil` wrapper with stable `div_cell` after - // https://github.com/rust-lang/rust/issues/88581 - let buffer_size = bit_util::ceil(merged_batch.num_rows(), 8); + let n_rows = merged_batch.num_rows(); + let buffer_size = n_rows.div_ceil(8); reservation.try_grow(buffer_size)?; metrics.build_mem_used.add(buffer_size); - let mut buffer = BooleanBufferBuilder::new(merged_batch.num_rows()); - buffer.append_n(merged_batch.num_rows(), false); + let mut buffer = BooleanBufferBuilder::new(n_rows); + buffer.append_n(n_rows, false); buffer } else { BooleanBufferBuilder::new(0) From 021a500b81aa291002342aa138abe8fad5896104 Mon Sep 17 00:00:00 2001 From: Huaijin Date: Tue, 10 Dec 2024 04:26:19 +0800 Subject: [PATCH 21/35] Fix hash join with sort push down (#13560) * fix: join with sort push down * chore: insert some value * apply suggestion * recover handle_costom_pushdown change * apply suggestion * add more test * add partition --- .../src/physical_optimizer/sort_pushdown.rs | 101 +++++++++++ datafusion/sqllogictest/test_files/joins.slt | 171 +++++++++++++----- 2 files changed, 228 insertions(+), 44 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs b/datafusion/core/src/physical_optimizer/sort_pushdown.rs index d48c7118cb8e..6c761f674b3b 100644 --- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs @@ -28,6 +28,7 @@ use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::tree_node::PlanContext; use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties}; +use arrow_schema::SchemaRef; use datafusion_common::tree_node::{ ConcreteTreeNode, Transformed, TreeNode, TreeNodeRecursion, @@ -38,6 +39,8 @@ use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::utils::collect_columns; use 
datafusion_physical_expr::PhysicalSortRequirement; use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; +use datafusion_physical_plan::joins::utils::ColumnIndex; +use datafusion_physical_plan::joins::HashJoinExec; /// This is a "data class" we use within the [`EnforceSorting`] rule to push /// down [`SortExec`] in the plan. In some cases, we can reduce the total @@ -294,6 +297,8 @@ fn pushdown_requirement_to_children( .then(|| LexRequirement::new(parent_required.to_vec())); Ok(Some(vec![req])) } + } else if let Some(hash_join) = plan.as_any().downcast_ref::() { + handle_hash_join(hash_join, parent_required) } else { handle_custom_pushdown(plan, parent_required, maintains_input_order) } @@ -606,6 +611,102 @@ fn handle_custom_pushdown( } } +// For hash join we only maintain the input order for the right child +// for join type: Inner, Right, RightSemi, RightAnti +fn handle_hash_join( + plan: &HashJoinExec, + parent_required: &LexRequirement, +) -> Result>>> { + // If there's no requirement from the parent or the plan has no children + // or the join type is not Inner, Right, RightSemi, RightAnti, return early + if parent_required.is_empty() || !plan.maintains_input_order()[1] { + return Ok(None); + } + + // Collect all unique column indices used in the parent-required sorting expression + let all_indices: HashSet = parent_required + .iter() + .flat_map(|order| { + collect_columns(&order.expr) + .into_iter() + .map(|col| col.index()) + .collect::>() + }) + .collect(); + + let column_indices = build_join_column_index(plan); + let projected_indices: Vec<_> = if let Some(projection) = &plan.projection { + projection.iter().map(|&i| &column_indices[i]).collect() + } else { + column_indices.iter().collect() + }; + let len_of_left_fields = projected_indices + .iter() + .filter(|ci| ci.side == JoinSide::Left) + .count(); + + let all_from_right_child = all_indices.iter().all(|i| *i >= len_of_left_fields); + + // If all columns are from the right 
child, update the parent requirements + if all_from_right_child { + // Transform the parent-required expression for the child schema by adjusting columns + let updated_parent_req = parent_required + .iter() + .map(|req| { + let child_schema = plan.children()[1].schema(); + let updated_columns = Arc::clone(&req.expr) + .transform_up(|expr| { + if let Some(col) = expr.as_any().downcast_ref::() { + let index = projected_indices[col.index()].index; + Ok(Transformed::yes(Arc::new(Column::new( + child_schema.field(index).name(), + index, + )))) + } else { + Ok(Transformed::no(expr)) + } + })? + .data; + Ok(PhysicalSortRequirement::new(updated_columns, req.options)) + }) + .collect::>>()?; + + // Populating with the updated requirements for children that maintain order + Ok(Some(vec![ + None, + Some(LexRequirement::new(updated_parent_req)), + ])) + } else { + Ok(None) + } +} + +// this function is used to build the column index for the hash join +// push down sort requirements to the right child +fn build_join_column_index(plan: &HashJoinExec) -> Vec { + let map_fields = |schema: SchemaRef, side: JoinSide| { + schema + .fields() + .iter() + .enumerate() + .map(|(index, _)| ColumnIndex { index, side }) + .collect::>() + }; + + match plan.join_type() { + JoinType::Inner | JoinType::Right => { + map_fields(plan.left().schema(), JoinSide::Left) + .into_iter() + .chain(map_fields(plan.right().schema(), JoinSide::Right)) + .collect::>() + } + JoinType::RightSemi | JoinType::RightAnti => { + map_fields(plan.right().schema(), JoinSide::Right) + } + _ => unreachable!("unexpected join type: {}", plan.join_type()), + } +} + /// Define the Requirements Compatibility #[derive(Debug)] enum RequirementsCompatibility { diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index e636e93007a4..62f625119897 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -2864,13 +2864,13 @@ 
explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 WHERE t1_id I ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----CoalesceBatchesExec: target_batch_size=2 -04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] -05)--------CoalesceBatchesExec: target_batch_size=2 -06)----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 -07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -08)--------------MemoryExec: partitions=1, partition_sizes=[1] +02)--CoalesceBatchesExec: target_batch_size=2 +03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] +04)------CoalesceBatchesExec: target_batch_size=2 +05)--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +07)------------MemoryExec: partitions=1, partition_sizes=[1] +08)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 09)--------CoalesceBatchesExec: target_batch_size=2 10)----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 11)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 @@ -2905,13 +2905,13 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 LEFT SEMI JOI ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----CoalesceBatchesExec: target_batch_size=2 -04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] -05)--------CoalesceBatchesExec: target_batch_size=2 -06)----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 -07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -08)--------------MemoryExec: 
partitions=1, partition_sizes=[1] +02)--CoalesceBatchesExec: target_batch_size=2 +03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] +04)------CoalesceBatchesExec: target_batch_size=2 +05)--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +07)------------MemoryExec: partitions=1, partition_sizes=[1] +08)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 09)--------CoalesceBatchesExec: target_batch_size=2 10)----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 11)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 @@ -2967,10 +2967,10 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 WHERE t1_id I ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----CoalesceBatchesExec: target_batch_size=2 -04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] -05)--------MemoryExec: partitions=1, partition_sizes=[1] +02)--CoalesceBatchesExec: target_batch_size=2 +03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] +04)------MemoryExec: partitions=1, partition_sizes=[1] +05)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 06)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 07)----------MemoryExec: partitions=1, partition_sizes=[1] @@ -3003,10 +3003,10 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 LEFT SEMI JOI ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----CoalesceBatchesExec: target_batch_size=2 -04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] 
-05)--------MemoryExec: partitions=1, partition_sizes=[1] +02)--CoalesceBatchesExec: target_batch_size=2 +03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] +04)------MemoryExec: partitions=1, partition_sizes=[1] +05)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 06)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 07)----------MemoryExec: partitions=1, partition_sizes=[1] @@ -3061,13 +3061,13 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t1 t1 WHER ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----CoalesceBatchesExec: target_batch_size=2 -04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0 -05)--------CoalesceBatchesExec: target_batch_size=2 -06)----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 -07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -08)--------------MemoryExec: partitions=1, partition_sizes=[1] +02)--CoalesceBatchesExec: target_batch_size=2 +03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0 +04)------CoalesceBatchesExec: target_batch_size=2 +05)--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +07)------------MemoryExec: partitions=1, partition_sizes=[1] +08)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 09)--------CoalesceBatchesExec: target_batch_size=2 10)----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 11)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 @@ -3083,13 +3083,13 @@ explain SELECT t1_id, t1_name, t1_int FROM 
right_semi_anti_join_table_t2 t2 RIGH ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----CoalesceBatchesExec: target_batch_size=2 -04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1 -05)--------CoalesceBatchesExec: target_batch_size=2 -06)----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 -07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -08)--------------MemoryExec: partitions=1, partition_sizes=[1] +02)--CoalesceBatchesExec: target_batch_size=2 +03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1 +04)------CoalesceBatchesExec: target_batch_size=2 +05)--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +07)------------MemoryExec: partitions=1, partition_sizes=[1] +08)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 09)--------CoalesceBatchesExec: target_batch_size=2 10)----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 11)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 @@ -3143,10 +3143,10 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t1 t1 WHER ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----CoalesceBatchesExec: target_batch_size=2 -04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0 -05)--------MemoryExec: partitions=1, partition_sizes=[1] +02)--CoalesceBatchesExec: target_batch_size=2 +03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], 
filter=t2_name@1 != t1_name@0 +04)------MemoryExec: partitions=1, partition_sizes=[1] +05)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 06)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 07)----------MemoryExec: partitions=1, partition_sizes=[1] @@ -3160,10 +3160,10 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t2 t2 RIGH ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----CoalesceBatchesExec: target_batch_size=2 -04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1 -05)--------MemoryExec: partitions=1, partition_sizes=[1] +02)--CoalesceBatchesExec: target_batch_size=2 +03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1 +04)------MemoryExec: partitions=1, partition_sizes=[1] +05)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 06)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 07)----------MemoryExec: partitions=1, partition_sizes=[1] @@ -4313,3 +4313,86 @@ physical_plan 04)------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(binary_col@0, binary_col@0)] 05)--------MemoryExec: partitions=1, partition_sizes=[1] 06)--------MemoryExec: partitions=1, partition_sizes=[1] + +# Test hash join sort push down +# Issue: https://github.com/apache/datafusion/issues/13559 +statement ok +CREATE TABLE test(a INT, b INT, c INT) + +statement ok +insert into test values (1,2,3), (4,5,6), (null, 7, 8), (8, null, 9), (9, 10, null) + +statement ok +set datafusion.execution.target_partitions = 2; + +query TT +explain select * from test where a in (select a from test where b > 3) order by c desc nulls first; +---- +logical_plan +01)Sort: test.c DESC NULLS FIRST +02)--LeftSemi Join: test.a = __correlated_sq_1.a 
+03)----TableScan: test projection=[a, b, c] +04)----SubqueryAlias: __correlated_sq_1 +05)------Projection: test.a +06)--------Filter: test.b > Int32(3) +07)----------TableScan: test projection=[a, b] +physical_plan +01)SortPreservingMergeExec: [c@2 DESC] +02)--CoalesceBatchesExec: target_batch_size=3 +03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(a@0, a@0)] +04)------CoalesceBatchesExec: target_batch_size=3 +05)--------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 +06)----------CoalesceBatchesExec: target_batch_size=3 +07)------------FilterExec: b@1 > 3, projection=[a@0] +08)--------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +09)----------------MemoryExec: partitions=1, partition_sizes=[1] +10)------SortExec: expr=[c@2 DESC], preserve_partitioning=[true] +11)--------CoalesceBatchesExec: target_batch_size=3 +12)----------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 +13)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +14)--------------MemoryExec: partitions=1, partition_sizes=[1] + +query TT +explain select * from test where a in (select a from test where b > 3) order by c desc nulls last; +---- +logical_plan +01)Sort: test.c DESC NULLS LAST +02)--LeftSemi Join: test.a = __correlated_sq_1.a +03)----TableScan: test projection=[a, b, c] +04)----SubqueryAlias: __correlated_sq_1 +05)------Projection: test.a +06)--------Filter: test.b > Int32(3) +07)----------TableScan: test projection=[a, b] +physical_plan +01)SortPreservingMergeExec: [c@2 DESC NULLS LAST] +02)--CoalesceBatchesExec: target_batch_size=3 +03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(a@0, a@0)] +04)------CoalesceBatchesExec: target_batch_size=3 +05)--------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 +06)----------CoalesceBatchesExec: target_batch_size=3 +07)------------FilterExec: b@1 > 3, projection=[a@0] +08)--------------RepartitionExec: 
partitioning=RoundRobinBatch(2), input_partitions=1 +09)----------------MemoryExec: partitions=1, partition_sizes=[1] +10)------SortExec: expr=[c@2 DESC NULLS LAST], preserve_partitioning=[true] +11)--------CoalesceBatchesExec: target_batch_size=3 +12)----------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 +13)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +14)--------------MemoryExec: partitions=1, partition_sizes=[1] + +query III +select * from test where a in (select a from test where b > 3) order by c desc nulls first; +---- +9 10 NULL +4 5 6 + +query III +select * from test where a in (select a from test where b > 3) order by c desc nulls last; +---- +4 5 6 +9 10 NULL + +statement ok +DROP TABLE test + +statement ok +set datafusion.execution.target_partitions = 1; From ec5e038036c00025905d9383faa86af018cf73ea Mon Sep 17 00:00:00 2001 From: Zhang Li Date: Tue, 10 Dec 2024 05:49:19 +0800 Subject: [PATCH 22/35] Improve substr() performance by avoiding using owned string (#13688) Co-authored-by: zhangli20 --- datafusion/functions/src/unicode/substr.rs | 77 +++++++++++----------- 1 file changed, 40 insertions(+), 37 deletions(-) diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs index 141984cf2674..687f77dbef5b 100644 --- a/datafusion/functions/src/unicode/substr.rs +++ b/datafusion/functions/src/unicode/substr.rs @@ -21,8 +21,8 @@ use std::sync::{Arc, OnceLock}; use crate::strings::{make_and_append_view, StringArrayType}; use crate::utils::{make_scalar_function, utf8_to_str_type}; use arrow::array::{ - Array, ArrayIter, ArrayRef, AsArray, GenericStringArray, Int64Array, OffsetSizeTrait, - StringViewArray, + Array, ArrayIter, ArrayRef, AsArray, GenericStringBuilder, Int64Array, + OffsetSizeTrait, StringViewArray, }; use arrow::datatypes::DataType; use arrow_buffer::{NullBufferBuilder, ScalarBuffer}; @@ -448,10 +448,9 @@ where match args.len() { 1 => { let iter = 
ArrayIter::new(string_array); - - let result = iter - .zip(start_array.iter()) - .map(|(string, start)| match (string, start) { + let mut result_builder = GenericStringBuilder::::new(); + for (string, start) in iter.zip(start_array.iter()) { + match (string, start) { (Some(string), Some(start)) => { let (start, end) = get_true_start_end( string, @@ -460,47 +459,51 @@ where enable_ascii_fast_path, ); // start, end is byte-based let substr = &string[start..end]; - Some(substr.to_string()) + result_builder.append_value(substr); } - _ => None, - }) - .collect::>(); - Ok(Arc::new(result) as ArrayRef) + _ => { + result_builder.append_null(); + } + } + } + Ok(Arc::new(result_builder.finish()) as ArrayRef) } 2 => { let iter = ArrayIter::new(string_array); let count_array = count_array_opt.unwrap(); + let mut result_builder = GenericStringBuilder::::new(); - let result = iter - .zip(start_array.iter()) - .zip(count_array.iter()) - .map(|((string, start), count)| { - match (string, start, count) { - (Some(string), Some(start), Some(count)) => { - if count < 0 { - exec_err!( + for ((string, start), count) in + iter.zip(start_array.iter()).zip(count_array.iter()) + { + match (string, start, count) { + (Some(string), Some(start), Some(count)) => { + if count < 0 { + return exec_err!( "negative substring length not allowed: substr(, {start}, {count})" - ) - } else { - if start == i64::MIN { - return exec_err!("negative overflow when calculating skip value"); - } - let (start, end) = get_true_start_end( - string, - start, - Some(count as u64), - enable_ascii_fast_path, - ); // start, end is byte-based - let substr = &string[start..end]; - Ok(Some(substr.to_string())) + ); + } else { + if start == i64::MIN { + return exec_err!( + "negative overflow when calculating skip value" + ); } + let (start, end) = get_true_start_end( + string, + start, + Some(count as u64), + enable_ascii_fast_path, + ); // start, end is byte-based + let substr = &string[start..end]; + 
result_builder.append_value(substr); } - _ => Ok(None), } - }) - .collect::>>()?; - - Ok(Arc::new(result) as ArrayRef) + _ => { + result_builder.append_null(); + } + } + } + Ok(Arc::new(result_builder.finish()) as ArrayRef) } other => { exec_err!("substr was called with {other} arguments. It requires 2 or 3.") From 412d3f6bb0cf77b520d7a0a730db88324178467c Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 9 Dec 2024 14:49:58 -0700 Subject: [PATCH 23/35] reinstate down_cast_any_ref (#13705) --- .../physical-expr-common/src/physical_expr.rs | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index 93bdcdef8ea0..c2e892d63da0 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -214,6 +214,21 @@ pub fn with_new_children_if_necessary( } } +#[deprecated(since = "44.0.0")] +pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any { + if any.is::>() { + any.downcast_ref::>() + .unwrap() + .as_any() + } else if any.is::>() { + any.downcast_ref::>() + .unwrap() + .as_any() + } else { + any + } +} + /// Returns [`Display`] able a list of [`PhysicalExpr`] /// /// Example output: `[a + 1, b]` From dc17dd61186cd90b169cecc1669a14e451e35ad1 Mon Sep 17 00:00:00 2001 From: Tai Le Manh Date: Tue, 10 Dec 2024 09:33:25 +0700 Subject: [PATCH 24/35] Optimize performance of `character_length` function (#13696) * Optimize performance of function Signed-off-by: Tai Le Manh * Add pre-check array is null * Fix clippy warnings --------- Signed-off-by: Tai Le Manh --- .../functions/src/unicode/character_length.rs | 57 +++++++++++++------ 1 file changed, 39 insertions(+), 18 deletions(-) diff --git a/datafusion/functions/src/unicode/character_length.rs b/datafusion/functions/src/unicode/character_length.rs index 822bdca9aca8..ad51a8ef72fb 100644 --- a/datafusion/functions/src/unicode/character_length.rs 
+++ b/datafusion/functions/src/unicode/character_length.rs @@ -18,7 +18,7 @@ use crate::strings::StringArrayType; use crate::utils::{make_scalar_function, utf8_to_int_type}; use arrow::array::{ - Array, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, PrimitiveArray, + Array, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, PrimitiveBuilder, }; use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; use datafusion_common::Result; @@ -136,31 +136,52 @@ fn character_length(args: &[ArrayRef]) -> Result { } } -fn character_length_general<'a, T: ArrowPrimitiveType, V: StringArrayType<'a>>( - array: V, -) -> Result +fn character_length_general<'a, T, V>(array: V) -> Result where + T: ArrowPrimitiveType, T::Native: OffsetSizeTrait, + V: StringArrayType<'a>, { + let mut builder = PrimitiveBuilder::::with_capacity(array.len()); + // String characters are variable length encoded in UTF-8, counting the // number of chars requires expensive decoding, however checking if the // string is ASCII only is relatively cheap. // If strings are ASCII only, count bytes instead. 
let is_array_ascii_only = array.is_ascii(); - let iter = array.iter(); - let result = iter - .map(|string| { - string.map(|string: &str| { - if is_array_ascii_only { - T::Native::usize_as(string.len()) - } else { - T::Native::usize_as(string.chars().count()) - } - }) - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) + if array.null_count() == 0 { + if is_array_ascii_only { + for i in 0..array.len() { + let value = array.value(i); + builder.append_value(T::Native::usize_as(value.len())); + } + } else { + for i in 0..array.len() { + let value = array.value(i); + builder.append_value(T::Native::usize_as(value.chars().count())); + } + } + } else if is_array_ascii_only { + for i in 0..array.len() { + if array.is_null(i) { + builder.append_null(); + } else { + let value = array.value(i); + builder.append_value(T::Native::usize_as(value.len())); + } + } + } else { + for i in 0..array.len() { + if array.is_null(i) { + builder.append_null(); + } else { + let value = array.value(i); + builder.append_value(T::Native::usize_as(value.chars().count())); + } + } + } + + Ok(Arc::new(builder.finish()) as ArrayRef) } #[cfg(test)] From 9b57875eeb30ddd5bf6f3bdf72cffa505fc2ba87 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 10 Dec 2024 02:21:58 -0800 Subject: [PATCH 25/35] Update prost-build requirement from =0.13.3 to =0.13.4 (#13698) Updates the requirements on [prost-build](https://github.com/tokio-rs/prost) to permit the latest version. - [Release notes](https://github.com/tokio-rs/prost/releases) - [Changelog](https://github.com/tokio-rs/prost/blob/master/CHANGELOG.md) - [Commits](https://github.com/tokio-rs/prost/compare/v0.13.3...v0.13.4) --- updated-dependencies: - dependency-name: prost-build dependency-type: direct:production ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- datafusion/proto-common/gen/Cargo.toml | 2 +- datafusion/proto/gen/Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/proto-common/gen/Cargo.toml b/datafusion/proto-common/gen/Cargo.toml index da5bc6029ff9..21fc9eccb40c 100644 --- a/datafusion/proto-common/gen/Cargo.toml +++ b/datafusion/proto-common/gen/Cargo.toml @@ -35,4 +35,4 @@ workspace = true [dependencies] # Pin these dependencies so that the generated output is deterministic pbjson-build = "=0.7.0" -prost-build = "=0.13.3" +prost-build = "=0.13.4" diff --git a/datafusion/proto/gen/Cargo.toml b/datafusion/proto/gen/Cargo.toml index 297406becada..dda72d20a159 100644 --- a/datafusion/proto/gen/Cargo.toml +++ b/datafusion/proto/gen/Cargo.toml @@ -35,4 +35,4 @@ workspace = true [dependencies] # Pin these dependencies so that the generated output is deterministic pbjson-build = "=0.7.0" -prost-build = "=0.13.3" +prost-build = "=0.13.4" From 4a08545ee2cc4fc5dca33c45360ba03864494e55 Mon Sep 17 00:00:00 2001 From: Oleks V Date: Tue, 10 Dec 2024 12:18:39 -0800 Subject: [PATCH 26/35] Minor: Output elapsed time for sql logic test (#13718) * Minor: Output elapsed time for sql logic test --- datafusion/sqllogictest/bin/sqllogictests.rs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/datafusion/sqllogictest/bin/sqllogictests.rs b/datafusion/sqllogictest/bin/sqllogictests.rs index 12c0e27ea911..176bd3229125 100644 --- a/datafusion/sqllogictest/bin/sqllogictests.rs +++ b/datafusion/sqllogictest/bin/sqllogictests.rs @@ -15,10 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-use std::ffi::OsStr; -use std::fs; -use std::path::{Path, PathBuf}; - use clap::Parser; use datafusion_common::utils::get_available_parallelism; use datafusion_sqllogictest::{DataFusion, TestContext}; @@ -26,6 +22,9 @@ use futures::stream::StreamExt; use itertools::Itertools; use log::info; use sqllogictest::strict_column_validator; +use std::ffi::OsStr; +use std::fs; +use std::path::{Path, PathBuf}; use datafusion_common::{exec_datafusion_err, exec_err, DataFusionError, Result}; use datafusion_common_runtime::SpawnedTask; @@ -100,7 +99,8 @@ async fn run_tests() -> Result<()> { let errors: Vec<_> = futures::stream::iter(read_test_files(&options)?) .map(|test_file| { SpawnedTask::spawn(async move { - println!("Running {:?}", test_file.relative_path); + let file_path = test_file.relative_path.clone(); + let start = datafusion::common::instant::Instant::now(); if options.complete { run_complete_file(test_file).await?; } else if options.postgres_runner { @@ -108,6 +108,7 @@ async fn run_tests() -> Result<()> { } else { run_test_file(test_file).await?; } + println!("Executed {:?}. 
Took {:?}", file_path, start.elapsed()); Ok(()) as Result<()> }) .join() From d02d587b922028be7b5c3996fa5d85547d3d6f12 Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Wed, 11 Dec 2024 09:14:44 +0800 Subject: [PATCH 27/35] refactor: simplify the `make_udf_function` macro (#13712) --- datafusion/functions/src/core/mod.rs | 22 +++++----- datafusion/functions/src/crypto/mod.rs | 12 +++--- datafusion/functions/src/datetime/mod.rs | 54 ++++++++---------------- datafusion/functions/src/encoding/mod.rs | 4 +- datafusion/functions/src/macros.rs | 25 ++++++----- datafusion/functions/src/math/mod.rs | 53 +++++++---------------- datafusion/functions/src/regex/mod.rs | 12 ++---- datafusion/functions/src/string/mod.rs | 44 +++++++++---------- datafusion/functions/src/unicode/mod.rs | 28 ++++++------ 9 files changed, 101 insertions(+), 153 deletions(-) diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index 24d26c539539..bd8305cd56d8 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -35,17 +35,17 @@ pub mod r#struct; pub mod version; // create UDFs -make_udf_function!(arrow_cast::ArrowCastFunc, ARROW_CAST, arrow_cast); -make_udf_function!(nullif::NullIfFunc, NULLIF, nullif); -make_udf_function!(nvl::NVLFunc, NVL, nvl); -make_udf_function!(nvl2::NVL2Func, NVL2, nvl2); -make_udf_function!(arrowtypeof::ArrowTypeOfFunc, ARROWTYPEOF, arrow_typeof); -make_udf_function!(r#struct::StructFunc, STRUCT, r#struct); -make_udf_function!(named_struct::NamedStructFunc, NAMED_STRUCT, named_struct); -make_udf_function!(getfield::GetFieldFunc, GET_FIELD, get_field); -make_udf_function!(coalesce::CoalesceFunc, COALESCE, coalesce); -make_udf_function!(greatest::GreatestFunc, GREATEST, greatest); -make_udf_function!(version::VersionFunc, VERSION, version); +make_udf_function!(arrow_cast::ArrowCastFunc, arrow_cast); +make_udf_function!(nullif::NullIfFunc, nullif); +make_udf_function!(nvl::NVLFunc, nvl); 
+make_udf_function!(nvl2::NVL2Func, nvl2); +make_udf_function!(arrowtypeof::ArrowTypeOfFunc, arrow_typeof); +make_udf_function!(r#struct::StructFunc, r#struct); +make_udf_function!(named_struct::NamedStructFunc, named_struct); +make_udf_function!(getfield::GetFieldFunc, get_field); +make_udf_function!(coalesce::CoalesceFunc, coalesce); +make_udf_function!(greatest::GreatestFunc, greatest); +make_udf_function!(version::VersionFunc, version); pub mod expr_fn { use datafusion_expr::{Expr, Literal}; diff --git a/datafusion/functions/src/crypto/mod.rs b/datafusion/functions/src/crypto/mod.rs index 46177fc22b60..62ea3c2e2737 100644 --- a/datafusion/functions/src/crypto/mod.rs +++ b/datafusion/functions/src/crypto/mod.rs @@ -27,12 +27,12 @@ pub mod sha224; pub mod sha256; pub mod sha384; pub mod sha512; -make_udf_function!(digest::DigestFunc, DIGEST, digest); -make_udf_function!(md5::Md5Func, MD5, md5); -make_udf_function!(sha224::SHA224Func, SHA224, sha224); -make_udf_function!(sha256::SHA256Func, SHA256, sha256); -make_udf_function!(sha384::SHA384Func, SHA384, sha384); -make_udf_function!(sha512::SHA512Func, SHA512, sha512); +make_udf_function!(digest::DigestFunc, digest); +make_udf_function!(md5::Md5Func, md5); +make_udf_function!(sha224::SHA224Func, sha224); +make_udf_function!(sha256::SHA256Func, sha256); +make_udf_function!(sha384::SHA384Func, sha384); +make_udf_function!(sha512::SHA512Func, sha512); pub mod expr_fn { export_functions!(( diff --git a/datafusion/functions/src/datetime/mod.rs b/datafusion/functions/src/datetime/mod.rs index db4e365267dd..96ca63010ee4 100644 --- a/datafusion/functions/src/datetime/mod.rs +++ b/datafusion/functions/src/datetime/mod.rs @@ -37,43 +37,23 @@ pub mod to_timestamp; pub mod to_unixtime; // create UDFs -make_udf_function!(current_date::CurrentDateFunc, CURRENT_DATE, current_date); -make_udf_function!(current_time::CurrentTimeFunc, CURRENT_TIME, current_time); -make_udf_function!(date_bin::DateBinFunc, DATE_BIN, date_bin); 
-make_udf_function!(date_part::DatePartFunc, DATE_PART, date_part); -make_udf_function!(date_trunc::DateTruncFunc, DATE_TRUNC, date_trunc); -make_udf_function!(make_date::MakeDateFunc, MAKE_DATE, make_date); -make_udf_function!( - from_unixtime::FromUnixtimeFunc, - FROM_UNIXTIME, - from_unixtime -); -make_udf_function!(now::NowFunc, NOW, now); -make_udf_function!(to_char::ToCharFunc, TO_CHAR, to_char); -make_udf_function!(to_date::ToDateFunc, TO_DATE, to_date); -make_udf_function!(to_local_time::ToLocalTimeFunc, TO_LOCAL_TIME, to_local_time); -make_udf_function!(to_unixtime::ToUnixtimeFunc, TO_UNIXTIME, to_unixtime); -make_udf_function!(to_timestamp::ToTimestampFunc, TO_TIMESTAMP, to_timestamp); -make_udf_function!( - to_timestamp::ToTimestampSecondsFunc, - TO_TIMESTAMP_SECONDS, - to_timestamp_seconds -); -make_udf_function!( - to_timestamp::ToTimestampMillisFunc, - TO_TIMESTAMP_MILLIS, - to_timestamp_millis -); -make_udf_function!( - to_timestamp::ToTimestampMicrosFunc, - TO_TIMESTAMP_MICROS, - to_timestamp_micros -); -make_udf_function!( - to_timestamp::ToTimestampNanosFunc, - TO_TIMESTAMP_NANOS, - to_timestamp_nanos -); +make_udf_function!(current_date::CurrentDateFunc, current_date); +make_udf_function!(current_time::CurrentTimeFunc, current_time); +make_udf_function!(date_bin::DateBinFunc, date_bin); +make_udf_function!(date_part::DatePartFunc, date_part); +make_udf_function!(date_trunc::DateTruncFunc, date_trunc); +make_udf_function!(make_date::MakeDateFunc, make_date); +make_udf_function!(from_unixtime::FromUnixtimeFunc, from_unixtime); +make_udf_function!(now::NowFunc, now); +make_udf_function!(to_char::ToCharFunc, to_char); +make_udf_function!(to_date::ToDateFunc, to_date); +make_udf_function!(to_local_time::ToLocalTimeFunc, to_local_time); +make_udf_function!(to_unixtime::ToUnixtimeFunc, to_unixtime); +make_udf_function!(to_timestamp::ToTimestampFunc, to_timestamp); +make_udf_function!(to_timestamp::ToTimestampSecondsFunc, to_timestamp_seconds); 
+make_udf_function!(to_timestamp::ToTimestampMillisFunc, to_timestamp_millis); +make_udf_function!(to_timestamp::ToTimestampMicrosFunc, to_timestamp_micros); +make_udf_function!(to_timestamp::ToTimestampNanosFunc, to_timestamp_nanos); // we cannot currently use the export_functions macro since it doesn't handle // functions with varargs currently diff --git a/datafusion/functions/src/encoding/mod.rs b/datafusion/functions/src/encoding/mod.rs index 48171370ad58..b0ddbd368a6b 100644 --- a/datafusion/functions/src/encoding/mod.rs +++ b/datafusion/functions/src/encoding/mod.rs @@ -21,8 +21,8 @@ use std::sync::Arc; pub mod inner; // create `encode` and `decode` UDFs -make_udf_function!(inner::EncodeFunc, ENCODE, encode); -make_udf_function!(inner::DecodeFunc, DECODE, decode); +make_udf_function!(inner::EncodeFunc, encode); +make_udf_function!(inner::DecodeFunc, decode); // Export the functions out of this package, both as expr_fn as well as a list of functions pub mod expr_fn { diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index bedec9bb2e6f..82308601490c 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -65,24 +65,23 @@ macro_rules! export_functions { }; } -/// Creates a singleton `ScalarUDF` of the `$UDF` function named `$GNAME` and a -/// function named `$NAME` which returns that singleton. +/// Creates a singleton `ScalarUDF` of the `$UDF` function and a function +/// named `$NAME` which returns that singleton. /// /// This is used to ensure creating the list of `ScalarUDF` only happens once. macro_rules! 
make_udf_function { - ($UDF:ty, $GNAME:ident, $NAME:ident) => { - #[doc = "Return a [`ScalarUDF`](datafusion_expr::ScalarUDF) implementation "] - #[doc = stringify!($UDF)] + ($UDF:ty, $NAME:ident) => { + #[doc = concat!("Return a [`ScalarUDF`](datafusion_expr::ScalarUDF) implementation of ", stringify!($NAME))] pub fn $NAME() -> std::sync::Arc { // Singleton instance of the function - static $GNAME: std::sync::LazyLock< + static INSTANCE: std::sync::LazyLock< std::sync::Arc, > = std::sync::LazyLock::new(|| { std::sync::Arc::new(datafusion_expr::ScalarUDF::new_from_impl( <$UDF>::new(), )) }); - std::sync::Arc::clone(&$GNAME) + std::sync::Arc::clone(&INSTANCE) } }; } @@ -134,13 +133,13 @@ macro_rules! downcast_arg { /// applies a unary floating function to the argument, and returns a value of the same type. /// /// $UDF: the name of the UDF struct that implements `ScalarUDFImpl` -/// $GNAME: a singleton instance of the UDF /// $NAME: the name of the function /// $UNARY_FUNC: the unary function to apply to the argument /// $OUTPUT_ORDERING: the output ordering calculation method of the function +/// $GET_DOC: the function to get the documentation of the UDF macro_rules! make_math_unary_udf { - ($UDF:ident, $GNAME:ident, $NAME:ident, $UNARY_FUNC:ident, $OUTPUT_ORDERING:expr, $EVALUATE_BOUNDS:expr, $GET_DOC:expr) => { - make_udf_function!($NAME::$UDF, $GNAME, $NAME); + ($UDF:ident, $NAME:ident, $UNARY_FUNC:ident, $OUTPUT_ORDERING:expr, $EVALUATE_BOUNDS:expr, $GET_DOC:expr) => { + make_udf_function!($NAME::$UDF, $NAME); mod $NAME { use std::any::Any; @@ -248,13 +247,13 @@ macro_rules! make_math_unary_udf { /// applies a binary floating function to the argument, and returns a value of the same type. 
/// /// $UDF: the name of the UDF struct that implements `ScalarUDFImpl` -/// $GNAME: a singleton instance of the UDF /// $NAME: the name of the function /// $BINARY_FUNC: the binary function to apply to the argument /// $OUTPUT_ORDERING: the output ordering calculation method of the function +/// $GET_DOC: the function to get the documentation of the UDF macro_rules! make_math_binary_udf { - ($UDF:ident, $GNAME:ident, $NAME:ident, $BINARY_FUNC:ident, $OUTPUT_ORDERING:expr, $GET_DOC:expr) => { - make_udf_function!($NAME::$UDF, $GNAME, $NAME); + ($UDF:ident, $NAME:ident, $BINARY_FUNC:ident, $OUTPUT_ORDERING:expr, $GET_DOC:expr) => { + make_udf_function!($NAME::$UDF, $NAME); mod $NAME { use std::any::Any; diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs index 1452bfdee5a0..4eb337a30110 100644 --- a/datafusion/functions/src/math/mod.rs +++ b/datafusion/functions/src/math/mod.rs @@ -40,10 +40,9 @@ pub mod signum; pub mod trunc; // Create UDFs -make_udf_function!(abs::AbsFunc, ABS, abs); +make_udf_function!(abs::AbsFunc, abs); make_math_unary_udf!( AcosFunc, - ACOS, acos, acos, super::acos_order, @@ -52,7 +51,6 @@ make_math_unary_udf!( ); make_math_unary_udf!( AcoshFunc, - ACOSH, acosh, acosh, super::acosh_order, @@ -61,7 +59,6 @@ make_math_unary_udf!( ); make_math_unary_udf!( AsinFunc, - ASIN, asin, asin, super::asin_order, @@ -70,7 +67,6 @@ make_math_unary_udf!( ); make_math_unary_udf!( AsinhFunc, - ASINH, asinh, asinh, super::asinh_order, @@ -79,7 +75,6 @@ make_math_unary_udf!( ); make_math_unary_udf!( AtanFunc, - ATAN, atan, atan, super::atan_order, @@ -88,7 +83,6 @@ make_math_unary_udf!( ); make_math_unary_udf!( AtanhFunc, - ATANH, atanh, atanh, super::atanh_order, @@ -97,7 +91,6 @@ make_math_unary_udf!( ); make_math_binary_udf!( Atan2, - ATAN2, atan2, atan2, super::atan2_order, @@ -105,7 +98,6 @@ make_math_binary_udf!( ); make_math_unary_udf!( CbrtFunc, - CBRT, cbrt, cbrt, super::cbrt_order, @@ -114,7 +106,6 @@ 
make_math_unary_udf!( ); make_math_unary_udf!( CeilFunc, - CEIL, ceil, ceil, super::ceil_order, @@ -123,7 +114,6 @@ make_math_unary_udf!( ); make_math_unary_udf!( CosFunc, - COS, cos, cos, super::cos_order, @@ -132,17 +122,15 @@ make_math_unary_udf!( ); make_math_unary_udf!( CoshFunc, - COSH, cosh, cosh, super::cosh_order, super::bounds::cosh_bounds, super::get_cosh_doc ); -make_udf_function!(cot::CotFunc, COT, cot); +make_udf_function!(cot::CotFunc, cot); make_math_unary_udf!( DegreesFunc, - DEGREES, degrees, to_degrees, super::degrees_order, @@ -151,31 +139,28 @@ make_math_unary_udf!( ); make_math_unary_udf!( ExpFunc, - EXP, exp, exp, super::exp_order, super::bounds::exp_bounds, super::get_exp_doc ); -make_udf_function!(factorial::FactorialFunc, FACTORIAL, factorial); +make_udf_function!(factorial::FactorialFunc, factorial); make_math_unary_udf!( FloorFunc, - FLOOR, floor, floor, super::floor_order, super::bounds::unbounded_bounds, super::get_floor_doc ); -make_udf_function!(log::LogFunc, LOG, log); -make_udf_function!(gcd::GcdFunc, GCD, gcd); -make_udf_function!(nans::IsNanFunc, ISNAN, isnan); -make_udf_function!(iszero::IsZeroFunc, ISZERO, iszero); -make_udf_function!(lcm::LcmFunc, LCM, lcm); +make_udf_function!(log::LogFunc, log); +make_udf_function!(gcd::GcdFunc, gcd); +make_udf_function!(nans::IsNanFunc, isnan); +make_udf_function!(iszero::IsZeroFunc, iszero); +make_udf_function!(lcm::LcmFunc, lcm); make_math_unary_udf!( LnFunc, - LN, ln, ln, super::ln_order, @@ -184,7 +169,6 @@ make_math_unary_udf!( ); make_math_unary_udf!( Log2Func, - LOG2, log2, log2, super::log2_order, @@ -193,31 +177,28 @@ make_math_unary_udf!( ); make_math_unary_udf!( Log10Func, - LOG10, log10, log10, super::log10_order, super::bounds::unbounded_bounds, super::get_log10_doc ); -make_udf_function!(nanvl::NanvlFunc, NANVL, nanvl); -make_udf_function!(pi::PiFunc, PI, pi); -make_udf_function!(power::PowerFunc, POWER, power); +make_udf_function!(nanvl::NanvlFunc, nanvl); 
+make_udf_function!(pi::PiFunc, pi); +make_udf_function!(power::PowerFunc, power); make_math_unary_udf!( RadiansFunc, - RADIANS, radians, to_radians, super::radians_order, super::bounds::radians_bounds, super::get_radians_doc ); -make_udf_function!(random::RandomFunc, RANDOM, random); -make_udf_function!(round::RoundFunc, ROUND, round); -make_udf_function!(signum::SignumFunc, SIGNUM, signum); +make_udf_function!(random::RandomFunc, random); +make_udf_function!(round::RoundFunc, round); +make_udf_function!(signum::SignumFunc, signum); make_math_unary_udf!( SinFunc, - SIN, sin, sin, super::sin_order, @@ -226,7 +207,6 @@ make_math_unary_udf!( ); make_math_unary_udf!( SinhFunc, - SINH, sinh, sinh, super::sinh_order, @@ -235,7 +215,6 @@ make_math_unary_udf!( ); make_math_unary_udf!( SqrtFunc, - SQRT, sqrt, sqrt, super::sqrt_order, @@ -244,7 +223,6 @@ make_math_unary_udf!( ); make_math_unary_udf!( TanFunc, - TAN, tan, tan, super::tan_order, @@ -253,14 +231,13 @@ make_math_unary_udf!( ); make_math_unary_udf!( TanhFunc, - TANH, tanh, tanh, super::tanh_order, super::bounds::tanh_bounds, super::get_tanh_doc ); -make_udf_function!(trunc::TruncFunc, TRUNC, trunc); +make_udf_function!(trunc::TruncFunc, trunc); pub mod expr_fn { export_functions!( diff --git a/datafusion/functions/src/regex/mod.rs b/datafusion/functions/src/regex/mod.rs index 803f51e915a9..13fbc049af58 100644 --- a/datafusion/functions/src/regex/mod.rs +++ b/datafusion/functions/src/regex/mod.rs @@ -25,14 +25,10 @@ pub mod regexpmatch; pub mod regexpreplace; // create UDFs -make_udf_function!(regexpcount::RegexpCountFunc, REGEXP_COUNT, regexp_count); -make_udf_function!(regexpmatch::RegexpMatchFunc, REGEXP_MATCH, regexp_match); -make_udf_function!(regexplike::RegexpLikeFunc, REGEXP_LIKE, regexp_like); -make_udf_function!( - regexpreplace::RegexpReplaceFunc, - REGEXP_REPLACE, - regexp_replace -); +make_udf_function!(regexpcount::RegexpCountFunc, regexp_count); +make_udf_function!(regexpmatch::RegexpMatchFunc, 
regexp_match); +make_udf_function!(regexplike::RegexpLikeFunc, regexp_like); +make_udf_function!(regexpreplace::RegexpReplaceFunc, regexp_replace); pub mod expr_fn { use datafusion_expr::Expr; diff --git a/datafusion/functions/src/string/mod.rs b/datafusion/functions/src/string/mod.rs index 622802f0142b..f156f070d960 100644 --- a/datafusion/functions/src/string/mod.rs +++ b/datafusion/functions/src/string/mod.rs @@ -45,28 +45,28 @@ pub mod to_hex; pub mod upper; pub mod uuid; // create UDFs -make_udf_function!(ascii::AsciiFunc, ASCII, ascii); -make_udf_function!(bit_length::BitLengthFunc, BIT_LENGTH, bit_length); -make_udf_function!(btrim::BTrimFunc, BTRIM, btrim); -make_udf_function!(chr::ChrFunc, CHR, chr); -make_udf_function!(concat::ConcatFunc, CONCAT, concat); -make_udf_function!(concat_ws::ConcatWsFunc, CONCAT_WS, concat_ws); -make_udf_function!(ends_with::EndsWithFunc, ENDS_WITH, ends_with); -make_udf_function!(initcap::InitcapFunc, INITCAP, initcap); -make_udf_function!(levenshtein::LevenshteinFunc, LEVENSHTEIN, levenshtein); -make_udf_function!(ltrim::LtrimFunc, LTRIM, ltrim); -make_udf_function!(lower::LowerFunc, LOWER, lower); -make_udf_function!(octet_length::OctetLengthFunc, OCTET_LENGTH, octet_length); -make_udf_function!(overlay::OverlayFunc, OVERLAY, overlay); -make_udf_function!(repeat::RepeatFunc, REPEAT, repeat); -make_udf_function!(replace::ReplaceFunc, REPLACE, replace); -make_udf_function!(rtrim::RtrimFunc, RTRIM, rtrim); -make_udf_function!(starts_with::StartsWithFunc, STARTS_WITH, starts_with); -make_udf_function!(split_part::SplitPartFunc, SPLIT_PART, split_part); -make_udf_function!(to_hex::ToHexFunc, TO_HEX, to_hex); -make_udf_function!(upper::UpperFunc, UPPER, upper); -make_udf_function!(uuid::UuidFunc, UUID, uuid); -make_udf_function!(contains::ContainsFunc, CONTAINS, contains); +make_udf_function!(ascii::AsciiFunc, ascii); +make_udf_function!(bit_length::BitLengthFunc, bit_length); +make_udf_function!(btrim::BTrimFunc, btrim); 
+make_udf_function!(chr::ChrFunc, chr); +make_udf_function!(concat::ConcatFunc, concat); +make_udf_function!(concat_ws::ConcatWsFunc, concat_ws); +make_udf_function!(ends_with::EndsWithFunc, ends_with); +make_udf_function!(initcap::InitcapFunc, initcap); +make_udf_function!(levenshtein::LevenshteinFunc, levenshtein); +make_udf_function!(ltrim::LtrimFunc, ltrim); +make_udf_function!(lower::LowerFunc, lower); +make_udf_function!(octet_length::OctetLengthFunc, octet_length); +make_udf_function!(overlay::OverlayFunc, overlay); +make_udf_function!(repeat::RepeatFunc, repeat); +make_udf_function!(replace::ReplaceFunc, replace); +make_udf_function!(rtrim::RtrimFunc, rtrim); +make_udf_function!(starts_with::StartsWithFunc, starts_with); +make_udf_function!(split_part::SplitPartFunc, split_part); +make_udf_function!(to_hex::ToHexFunc, to_hex); +make_udf_function!(upper::UpperFunc, upper); +make_udf_function!(uuid::UuidFunc, uuid); +make_udf_function!(contains::ContainsFunc, contains); pub mod expr_fn { use datafusion_expr::Expr; diff --git a/datafusion/functions/src/unicode/mod.rs b/datafusion/functions/src/unicode/mod.rs index 40915bc9efde..f31ece9196d8 100644 --- a/datafusion/functions/src/unicode/mod.rs +++ b/datafusion/functions/src/unicode/mod.rs @@ -34,22 +34,18 @@ pub mod substrindex; pub mod translate; // create UDFs -make_udf_function!( - character_length::CharacterLengthFunc, - CHARACTER_LENGTH, - character_length -); -make_udf_function!(find_in_set::FindInSetFunc, FIND_IN_SET, find_in_set); -make_udf_function!(left::LeftFunc, LEFT, left); -make_udf_function!(lpad::LPadFunc, LPAD, lpad); -make_udf_function!(right::RightFunc, RIGHT, right); -make_udf_function!(reverse::ReverseFunc, REVERSE, reverse); -make_udf_function!(rpad::RPadFunc, RPAD, rpad); -make_udf_function!(strpos::StrposFunc, STRPOS, strpos); -make_udf_function!(substr::SubstrFunc, SUBSTR, substr); -make_udf_function!(substr::SubstrFunc, SUBSTRING, substring); 
-make_udf_function!(substrindex::SubstrIndexFunc, SUBSTR_INDEX, substr_index); -make_udf_function!(translate::TranslateFunc, TRANSLATE, translate); +make_udf_function!(character_length::CharacterLengthFunc, character_length); +make_udf_function!(find_in_set::FindInSetFunc, find_in_set); +make_udf_function!(left::LeftFunc, left); +make_udf_function!(lpad::LPadFunc, lpad); +make_udf_function!(right::RightFunc, right); +make_udf_function!(reverse::ReverseFunc, reverse); +make_udf_function!(rpad::RPadFunc, rpad); +make_udf_function!(strpos::StrposFunc, strpos); +make_udf_function!(substr::SubstrFunc, substr); +make_udf_function!(substr::SubstrFunc, substring); +make_udf_function!(substrindex::SubstrIndexFunc, substr_index); +make_udf_function!(translate::TranslateFunc, translate); pub mod expr_fn { use datafusion_expr::Expr; From 0e413417d6bdf70ff831c31122911b02a093a675 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Wed, 11 Dec 2024 20:02:10 +0800 Subject: [PATCH 28/35] refactor: replace `Vec` with `IndexMap` for expression mappings in `ProjectionMapping` and `EquivalenceGroup` (#13675) * refactor: replace Vec with IndexMap for expression mappings in ProjectionMapping and EquivalenceGroup * chore * chore: Fix CI * chore: comment * chore: simplify --- .../physical-expr/src/equivalence/class.rs | 34 +++++++------------ 1 file changed, 13 insertions(+), 21 deletions(-) diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index d06a495d970a..cc26d12fb029 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -17,8 +17,8 @@ use super::{add_offset_to_expr, collapse_lex_req, ProjectionMapping}; use crate::{ - expressions::Column, physical_exprs_contains, LexOrdering, LexRequirement, - PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, PhysicalSortRequirement, + expressions::Column, LexOrdering, LexRequirement, PhysicalExpr, PhysicalExprRef, + 
PhysicalSortExpr, PhysicalSortRequirement, }; use std::fmt::Display; use std::sync::Arc; @@ -27,7 +27,7 @@ use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::JoinType; use datafusion_physical_expr_common::physical_expr::format_physical_expr_list; -use indexmap::IndexSet; +use indexmap::{IndexMap, IndexSet}; /// A structure representing a expression known to be constant in a physical execution plan. /// @@ -546,28 +546,20 @@ impl EquivalenceGroup { .collect::>(); (new_class.len() > 1).then_some(EquivalenceClass::new(new_class)) }); - // TODO: Convert the algorithm below to a version that uses `HashMap`. - // once `Arc` can be stored in `HashMap`. - // See issue: https://github.com/apache/datafusion/issues/8027 - let mut new_classes = vec![]; - for (source, target) in mapping.iter() { - if new_classes.is_empty() { - new_classes.push((source, vec![Arc::clone(target)])); - } - if let Some((_, values)) = - new_classes.iter_mut().find(|(key, _)| *key == source) - { - if !physical_exprs_contains(values, target) { - values.push(Arc::clone(target)); - } - } - } + // the key is the source expression and the value is the EquivalenceClass that contains the target expression of the source expression. + let mut new_classes: IndexMap, EquivalenceClass> = + IndexMap::new(); + mapping.iter().for_each(|(source, target)| { + new_classes + .entry(Arc::clone(source)) + .or_insert_with(EquivalenceClass::new_empty) + .push(Arc::clone(target)); + }); // Only add equivalence classes with at least two members as singleton // equivalence classes are meaningless. 
let new_classes = new_classes .into_iter() - .filter_map(|(_, values)| (values.len() > 1).then_some(values)) - .map(EquivalenceClass::new); + .filter_map(|(_, cls)| (cls.len() > 1).then_some(cls)); let classes = projected_classes.chain(new_classes).collect(); Self::new(classes) From a8fc264a8b1705a01053519e89ae70bbf00521e3 Mon Sep 17 00:00:00 2001 From: Eason <30045503+Eason0729@users.noreply.github.com> Date: Wed, 11 Dec 2024 20:21:17 +0800 Subject: [PATCH 29/35] Handle alias when parsing sql(parse_sql_expr) (#12939) * fix: Fix parse_sql_expr not handling alias * cargo fmt * fix parse_sql_expr example(remove alias) * add testing * add SUM udaf to TestContextProvider and modify test_sql_to_expr_with_alias for function * revert change on example `parse_sql_expr` --- .../examples/parse_sql_expr.rs | 10 ++-- .../core/src/execution/session_state.rs | 21 +++++-- datafusion/sql/src/expr/mod.rs | 60 +++++++++++++++++-- datafusion/sql/src/parser.rs | 9 +-- 4 files changed, 82 insertions(+), 18 deletions(-) diff --git a/datafusion-examples/examples/parse_sql_expr.rs b/datafusion-examples/examples/parse_sql_expr.rs index e23e5accae39..d8f0778e19e3 100644 --- a/datafusion-examples/examples/parse_sql_expr.rs +++ b/datafusion-examples/examples/parse_sql_expr.rs @@ -121,11 +121,11 @@ async fn query_parquet_demo() -> Result<()> { assert_batches_eq!( &[ - "+------------+----------------------+", - "| double_col | sum(?table?.int_col) |", - "+------------+----------------------+", - "| 10.1 | 4 |", - "+------------+----------------------+", + "+------------+-------------+", + "| double_col | sum_int_col |", + "+------------+-------------+", + "| 10.1 | 4 |", + "+------------+-------------+", ], &result ); diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 4ccad5ffd323..cef5d4c1ee2a 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -68,7 +68,7 @@ use 
datafusion_sql::planner::{ContextProvider, ParserOptions, PlannerContext, Sq use itertools::Itertools; use log::{debug, info}; use object_store::ObjectStore; -use sqlparser::ast::Expr as SQLExpr; +use sqlparser::ast::{Expr as SQLExpr, ExprWithAlias as SQLExprWithAlias}; use sqlparser::dialect::dialect_from_str; use std::any::Any; use std::collections::hash_map::Entry; @@ -500,11 +500,22 @@ impl SessionState { sql: &str, dialect: &str, ) -> datafusion_common::Result { + self.sql_to_expr_with_alias(sql, dialect).map(|x| x.expr) + } + + /// parse a sql string into a sqlparser-rs AST [`SQLExprWithAlias`]. + /// + /// See [`Self::create_logical_expr`] for parsing sql to [`Expr`]. + pub fn sql_to_expr_with_alias( + &self, + sql: &str, + dialect: &str, + ) -> datafusion_common::Result { let dialect = dialect_from_str(dialect).ok_or_else(|| { plan_datafusion_err!( "Unsupported SQL dialect: {dialect}. Available dialects: \ - Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \ - MsSQL, ClickHouse, BigQuery, Ansi." + Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \ + MsSQL, ClickHouse, BigQuery, Ansi." 
) })?; @@ -603,7 +614,7 @@ impl SessionState { ) -> datafusion_common::Result { let dialect = self.config.options().sql_parser.dialect.as_str(); - let sql_expr = self.sql_to_expr(sql, dialect)?; + let sql_expr = self.sql_to_expr_with_alias(sql, dialect)?; let provider = SessionContextProvider { state: self, @@ -611,7 +622,7 @@ impl SessionState { }; let query = SqlToRel::new_with_options(&provider, self.get_parser_options()); - query.sql_to_expr(sql_expr, df_schema, &mut PlannerContext::new()) + query.sql_to_expr_with_alias(sql_expr, df_schema, &mut PlannerContext::new()) } /// Returns the [`Analyzer`] for this session diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 57ac96951f1f..e8ec8d7b7d1c 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -23,7 +23,8 @@ use datafusion_expr::planner::{ use recursive::recursive; use sqlparser::ast::{ BinaryOperator, CastFormat, CastKind, DataType as SQLDataType, DictionaryField, - Expr as SQLExpr, MapEntry, StructField, Subscript, TrimWhereField, Value, + Expr as SQLExpr, ExprWithAlias as SQLExprWithAlias, MapEntry, StructField, Subscript, + TrimWhereField, Value, }; use datafusion_common::{ @@ -50,6 +51,19 @@ mod unary_op; mod value; impl SqlToRel<'_, S> { + pub(crate) fn sql_expr_to_logical_expr_with_alias( + &self, + sql: SQLExprWithAlias, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result { + let mut expr = + self.sql_expr_to_logical_expr(sql.expr, schema, planner_context)?; + if let Some(alias) = sql.alias { + expr = expr.alias(alias.value); + } + Ok(expr) + } pub(crate) fn sql_expr_to_logical_expr( &self, sql: SQLExpr, @@ -131,6 +145,20 @@ impl SqlToRel<'_, S> { ))) } + pub fn sql_to_expr_with_alias( + &self, + sql: SQLExprWithAlias, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result { + let mut expr = + self.sql_expr_to_logical_expr_with_alias(sql, schema, planner_context)?; + expr = 
self.rewrite_partial_qualifier(expr, schema); + self.validate_schema_satisfies_exprs(schema, &[expr.clone()])?; + let (expr, _) = expr.infer_placeholder_types(schema)?; + Ok(expr) + } + /// Generate a relational expression from a SQL expression pub fn sql_to_expr( &self, @@ -1091,8 +1119,11 @@ mod tests { None } - fn get_aggregate_meta(&self, _name: &str) -> Option> { - None + fn get_aggregate_meta(&self, name: &str) -> Option> { + match name { + "sum" => Some(datafusion_functions_aggregate::sum::sum_udaf()), + _ => None, + } } fn get_variable_type(&self, _variable_names: &[String]) -> Option { @@ -1112,7 +1143,7 @@ mod tests { } fn udaf_names(&self) -> Vec { - Vec::new() + vec!["sum".to_string()] } fn udwf_names(&self) -> Vec { @@ -1167,4 +1198,25 @@ mod tests { test_stack_overflow!(2048); test_stack_overflow!(4096); test_stack_overflow!(8192); + #[test] + fn test_sql_to_expr_with_alias() { + let schema = DFSchema::empty(); + let mut planner_context = PlannerContext::default(); + + let expr_str = "SUM(int_col) as sum_int_col"; + + let dialect = GenericDialect {}; + let mut parser = Parser::new(&dialect).try_with_sql(expr_str).unwrap(); + // from sqlparser + let sql_expr = parser.parse_expr_with_alias().unwrap(); + + let context_provider = TestContextProvider::new(); + let sql_to_rel = SqlToRel::new(&context_provider); + + let expr = sql_to_rel + .sql_expr_to_logical_expr_with_alias(sql_expr, &schema, &mut planner_context) + .unwrap(); + + assert!(matches!(expr, Expr::Alias(_))); + } } diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index bd1ed3145ef5..efec6020641c 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -20,9 +20,10 @@ use std::collections::VecDeque; use std::fmt; +use sqlparser::ast::ExprWithAlias; use sqlparser::{ ast::{ - ColumnDef, ColumnOptionDef, Expr, ObjectName, OrderByExpr, Query, + ColumnDef, ColumnOptionDef, ObjectName, OrderByExpr, Query, Statement as SQLStatement, TableConstraint, Value, 
 }, dialect::{keywords::Keyword, Dialect, GenericDialect}, @@ -328,7 +329,7 @@ impl<'a> DFParser<'a> { pub fn parse_sql_into_expr_with_dialect( sql: &str, dialect: &dyn Dialect, - ) -> Result { + ) -> Result { let mut parser = DFParser::new_with_dialect(sql, dialect)?; parser.parse_expr() } @@ -377,7 +378,7 @@ impl<'a> DFParser<'a> { } } - pub fn parse_expr(&mut self) -> Result { + pub fn parse_expr(&mut self) -> Result { if let Token::Word(w) = self.parser.peek_token().token { match w.keyword { Keyword::CREATE | Keyword::COPY | Keyword::EXPLAIN => { @@ -387,7 +388,7 @@ impl<'a> DFParser<'a> { } } - self.parser.parse_expr() + self.parser.parse_expr_with_alias() } /// Parse a SQL `COPY TO` statement From a505610e5986e674e97b4221e364591989d7a448 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 11 Dec 2024 07:27:46 -0500 Subject: [PATCH 30/35] Improve documentation for TableProvider (#13724) --- datafusion/catalog/src/table.rs | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/datafusion/catalog/src/table.rs b/datafusion/catalog/src/table.rs index b6752191d9a7..3c8960495588 100644 --- a/datafusion/catalog/src/table.rs +++ b/datafusion/catalog/src/table.rs @@ -33,7 +33,19 @@ use datafusion_expr::{ }; use datafusion_physical_plan::ExecutionPlan; -/// Source table +/// A named table which can be queried. +/// +/// Please see [`CatalogProvider`] for details of implementing a custom catalog. +/// +/// [`TableProvider`] represents a source of data which can provide data as +/// Apache Arrow `RecordBatch`es. Implementations of this trait provide +/// important information for planning such as: +/// +/// 1. [`Self::schema`]: The schema (columns and their types) of the table +/// 2. [`Self::supports_filters_pushdown`]: Should filters be pushed into this scan +/// 3.
[`Self::scan`]: An [`ExecutionPlan`] that can read data +/// +/// [`CatalogProvider`]: super::CatalogProvider #[async_trait] pub trait TableProvider: Debug + Sync + Send { /// Returns the table provider as [`Any`](std::any::Any) so that it can be From 4fb668bcc0909b79abaff70cf919e46b900002c6 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Wed, 11 Dec 2024 14:29:23 +0100 Subject: [PATCH 31/35] Reveal implementing type and return type in simple UDF implementations (#13730) Debug trait is useful for understanding what something is and how it's configured, especially if the implementation is behind dyn trait. --- datafusion/expr/src/expr_fn.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 681eb3c0afd5..a44dd24039dc 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -416,9 +416,10 @@ pub struct SimpleScalarUDF { impl Debug for SimpleScalarUDF { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("ScalarUDF") + f.debug_struct("SimpleScalarUDF") .field("name", &self.name) .field("signature", &self.signature) + .field("return_type", &self.return_type) .field("fun", &"") .finish() } @@ -524,9 +525,10 @@ pub struct SimpleAggregateUDF { impl Debug for SimpleAggregateUDF { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("AggregateUDF") + f.debug_struct("SimpleAggregateUDF") .field("name", &self.name) .field("signature", &self.signature) + .field("return_type", &self.return_type) .field("fun", &"") .finish() } From 1ab089e5563a736f350b3fabea6a32f132621d7b Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 11 Dec 2024 09:49:23 -0500 Subject: [PATCH 32/35] minor: Extract tests for `EXTRACT` AND `date_part` to their own file (#13731) --- datafusion/sqllogictest/test_files/expr.slt | 861 ----------------- .../test_files/expr/date_part.slt | 878 ++++++++++++++++++ 2 files changed, 878 
insertions(+), 861 deletions(-) create mode 100644 datafusion/sqllogictest/test_files/expr/date_part.slt diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index 9b8dfc2186be..2306eda77d35 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -832,867 +832,6 @@ SELECT ---- 0 NULL 0 NULL -# test_extract_date_part - -query error -SELECT EXTRACT("'''year'''" FROM timestamp '2020-09-08T12:00:00+00:00') - -query error -SELECT EXTRACT("'year'" FROM timestamp '2020-09-08T12:00:00+00:00') - -query I -SELECT date_part('YEAR', CAST('2000-01-01' AS DATE)) ----- -2000 - -query I -SELECT EXTRACT(year FROM timestamp '2020-09-08T12:00:00+00:00') ----- -2020 - -query I -SELECT EXTRACT("year" FROM timestamp '2020-09-08T12:00:00+00:00') ----- -2020 - -query I -SELECT EXTRACT('year' FROM timestamp '2020-09-08T12:00:00+00:00') ----- -2020 - -query I -SELECT date_part('QUARTER', CAST('2000-01-01' AS DATE)) ----- -1 - -query I -SELECT EXTRACT(quarter FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -3 - -query I -SELECT EXTRACT("quarter" FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -3 - -query I -SELECT EXTRACT('quarter' FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -3 - -query I -SELECT date_part('MONTH', CAST('2000-01-01' AS DATE)) ----- -1 - -query I -SELECT EXTRACT(month FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -9 - -query I -SELECT EXTRACT("month" FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -9 - -query I -SELECT EXTRACT('month' FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -9 - -query I -SELECT date_part('WEEK', CAST('2003-01-01' AS DATE)) ----- -1 - -query I -SELECT EXTRACT(WEEK FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -37 - -query I -SELECT EXTRACT("WEEK" FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -37 - -query I -SELECT EXTRACT('WEEK' FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -37 - 
-query I -SELECT date_part('DAY', CAST('2000-01-01' AS DATE)) ----- -1 - -query I -SELECT EXTRACT(day FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -8 - -query I -SELECT EXTRACT("day" FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -8 - -query I -SELECT EXTRACT('day' FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -8 - -query I -SELECT date_part('DOY', CAST('2000-01-01' AS DATE)) ----- -1 - -query I -SELECT EXTRACT(doy FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -252 - -query I -SELECT EXTRACT("doy" FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -252 - -query I -SELECT EXTRACT('doy' FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -252 - -query I -SELECT date_part('DOW', CAST('2000-01-01' AS DATE)) ----- -6 - -query I -SELECT EXTRACT(dow FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -2 - -query I -SELECT EXTRACT("dow" FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -2 - -query I -SELECT EXTRACT('dow' FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -2 - -query I -SELECT date_part('HOUR', CAST('2000-01-01' AS DATE)) ----- -0 - -query I -SELECT EXTRACT(hour FROM to_timestamp('2020-09-08T12:03:03+00:00')) ----- -12 - -query I -SELECT EXTRACT("hour" FROM to_timestamp('2020-09-08T12:03:03+00:00')) ----- -12 - -query I -SELECT EXTRACT('hour' FROM to_timestamp('2020-09-08T12:03:03+00:00')) ----- -12 - -query I -SELECT EXTRACT(minute FROM to_timestamp('2020-09-08T12:12:00+00:00')) ----- -12 - -query I -SELECT EXTRACT("minute" FROM to_timestamp('2020-09-08T12:12:00+00:00')) ----- -12 - -query I -SELECT EXTRACT('minute' FROM to_timestamp('2020-09-08T12:12:00+00:00')) ----- -12 - -query I -SELECT date_part('minute', to_timestamp('2020-09-08T12:12:00+00:00')) ----- -12 - -# make sure the return type is integer -query T -SELECT arrow_typeof(date_part('minute', to_timestamp('2020-09-08T12:12:00+00:00'))) ----- -Int32 - -query I -SELECT EXTRACT(second FROM timestamp '2020-09-08T12:00:12.12345678+00:00') ----- -12 - 
-query I -SELECT EXTRACT(millisecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00') ----- -12123 - -query I -SELECT EXTRACT(microsecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00') ----- -12123456 - -query error DataFusion error: Internal error: unit Nanosecond not supported -SELECT EXTRACT(nanosecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00') - -query I -SELECT EXTRACT("second" FROM timestamp '2020-09-08T12:00:12.12345678+00:00') ----- -12 - -query I -SELECT EXTRACT("millisecond" FROM timestamp '2020-09-08T12:00:12.12345678+00:00') ----- -12123 - -query I -SELECT EXTRACT("microsecond" FROM timestamp '2020-09-08T12:00:12.12345678+00:00') ----- -12123456 - -query error DataFusion error: Internal error: unit Nanosecond not supported -SELECT EXTRACT("nanosecond" FROM timestamp '2020-09-08T12:00:12.12345678+00:00') - -query I -SELECT EXTRACT('second' FROM timestamp '2020-09-08T12:00:12.12345678+00:00') ----- -12 - -query I -SELECT EXTRACT('millisecond' FROM timestamp '2020-09-08T12:00:12.12345678+00:00') ----- -12123 - -query I -SELECT EXTRACT('microsecond' FROM timestamp '2020-09-08T12:00:12.12345678+00:00') ----- -12123456 - -query error DataFusion error: Internal error: unit Nanosecond not supported -SELECT EXTRACT('nanosecond' FROM timestamp '2020-09-08T12:00:12.12345678+00:00') - - -# Keep precision when coercing Utf8 to Timestamp -query I -SELECT date_part('second', timestamp '2020-09-08T12:00:12.12345678+00:00') ----- -12 - -query I -SELECT date_part('millisecond', timestamp '2020-09-08T12:00:12.12345678+00:00') ----- -12123 - -query I -SELECT date_part('microsecond', timestamp '2020-09-08T12:00:12.12345678+00:00') ----- -12123456 - -query error DataFusion error: Internal error: unit Nanosecond not supported -SELECT date_part('nanosecond', timestamp '2020-09-08T12:00:12.12345678+00:00') - - -query I -SELECT date_part('second', '2020-09-08T12:00:12.12345678+00:00') ----- -12 - -query I -SELECT date_part('millisecond', 
'2020-09-08T12:00:12.12345678+00:00') ----- -12123 - -query I -SELECT date_part('microsecond', '2020-09-08T12:00:12.12345678+00:00') ----- -12123456 - -query error DataFusion error: Internal error: unit Nanosecond not supported -SELECT date_part('nanosecond', '2020-09-08T12:00:12.12345678+00:00') - -# test_date_part_time - -## time32 seconds -query I -SELECT date_part('hour', arrow_cast('23:32:50'::time, 'Time32(Second)')) ----- -23 - -query I -SELECT extract(hour from arrow_cast('23:32:50'::time, 'Time32(Second)')) ----- -23 - -query I -SELECT date_part('minute', arrow_cast('23:32:50'::time, 'Time32(Second)')) ----- -32 - -query I -SELECT extract(minute from arrow_cast('23:32:50'::time, 'Time32(Second)')) ----- -32 - -query I -SELECT date_part('second', arrow_cast('23:32:50'::time, 'Time32(Second)')) ----- -50 - -query I -SELECT extract(second from arrow_cast('23:32:50'::time, 'Time32(Second)')) ----- -50 - -query I -SELECT date_part('millisecond', arrow_cast('23:32:50'::time, 'Time32(Second)')) ----- -50000 - -query I -SELECT extract(millisecond from arrow_cast('23:32:50'::time, 'Time32(Second)')) ----- -50000 - -query I -SELECT date_part('microsecond', arrow_cast('23:32:50'::time, 'Time32(Second)')) ----- -50000000 - -query I -SELECT extract(microsecond from arrow_cast('23:32:50'::time, 'Time32(Second)')) ----- -50000000 - -query error DataFusion error: Internal error: unit Nanosecond not supported -SELECT extract(nanosecond from arrow_cast('23:32:50'::time, 'Time32(Second)')) - -query R -SELECT date_part('epoch', arrow_cast('23:32:50'::time, 'Time32(Second)')) ----- -84770 - -query R -SELECT extract(epoch from arrow_cast('23:32:50'::time, 'Time32(Second)')) ----- -84770 - -## time32 milliseconds -query I -SELECT date_part('hour', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -23 - -query I -SELECT extract(hour from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -23 - -query I -SELECT date_part('minute', 
arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -32 - -query I -SELECT extract(minute from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -32 - -query I -SELECT date_part('second', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -50 - -query I -SELECT extract(second from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -50 - -query I -SELECT date_part('millisecond', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -50123 - -query I -SELECT extract(millisecond from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -50123 - -query I -SELECT date_part('microsecond', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -50123000 - -query I -SELECT extract(microsecond from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -50123000 - -query error DataFusion error: Internal error: unit Nanosecond not supported -SELECT extract(nanosecond from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) - -query R -SELECT date_part('epoch', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -84770.123 - -query R -SELECT extract(epoch from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -84770.123 - -## time64 microseconds -query I -SELECT date_part('hour', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) ----- -23 - -query I -SELECT extract(hour from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) ----- -23 - -query I -SELECT date_part('minute', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) ----- -32 - -query I -SELECT extract(minute from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) ----- -32 - -query I -SELECT date_part('second', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) ----- -50 - -query I -SELECT extract(second from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) ----- -50 - -query I -SELECT date_part('millisecond', arrow_cast('23:32:50.123456'::time, 
'Time64(Microsecond)')) ----- -50123 - -query I -SELECT extract(millisecond from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) ----- -50123 - -query I -SELECT date_part('microsecond', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) ----- -50123456 - -query I -SELECT extract(microsecond from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) ----- -50123456 - -query error DataFusion error: Internal error: unit Nanosecond not supported -SELECT extract(nanosecond from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) - -query R -SELECT date_part('epoch', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) ----- -84770.123456 - -query R -SELECT extract(epoch from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) ----- -84770.123456 - -## time64 nanoseconds -query I -SELECT date_part('hour', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -23 - -query I -SELECT extract(hour from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -23 - -query I -SELECT date_part('minute', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -32 - -query I -SELECT extract(minute from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -32 - -query I -SELECT date_part('second', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -50 - -query I -select extract(second from '2024-08-09T12:13:14') ----- -14 - -query I -select extract(seconds from '2024-08-09T12:13:14') ----- -14 - -query I -SELECT extract(second from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -50 - -query I -SELECT date_part('millisecond', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -50123 - -query I -SELECT extract(millisecond from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -50123 - -# just some floating point stuff happening in the result here -query I -SELECT date_part('microsecond', 
arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -50123456 - -query I -SELECT extract(microsecond from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -50123456 - -query I -SELECT extract(us from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -50123456 - -query error DataFusion error: Internal error: unit Nanosecond not supported -SELECT date_part('nanosecond', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) - -query R -SELECT date_part('epoch', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -84770.123456789 - -query R -SELECT extract(epoch from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -84770.123456789 - -# test_extract_epoch - -query R -SELECT extract(epoch from '1870-01-01T07:29:10.256'::timestamp) ----- --3155646649.744 - -query R -SELECT extract(epoch from '2000-01-01T00:00:00.000'::timestamp) ----- -946684800 - -query R -SELECT extract(epoch from to_timestamp('2000-01-01T00:00:00+00:00')) ----- -946684800 - -query R -SELECT extract(epoch from NULL::timestamp) ----- -NULL - -query R -SELECT extract(epoch from arrow_cast('1970-01-01', 'Date32')) ----- -0 - -query R -SELECT extract(epoch from arrow_cast('1970-01-02', 'Date32')) ----- -86400 - -query R -SELECT extract(epoch from arrow_cast('1970-01-11', 'Date32')) ----- -864000 - -query R -SELECT extract(epoch from arrow_cast('1969-12-31', 'Date32')) ----- --86400 - -query R -SELECT extract(epoch from arrow_cast('1970-01-01', 'Date64')) ----- -0 - -query R -SELECT extract(epoch from arrow_cast('1970-01-02', 'Date64')) ----- -86400 - -query R -SELECT extract(epoch from arrow_cast('1970-01-11', 'Date64')) ----- -864000 - -query R -SELECT extract(epoch from arrow_cast('1969-12-31', 'Date64')) ----- --86400 - -# test_extract_interval - -query I -SELECT extract(year from arrow_cast('10 years', 'Interval(YearMonth)')) ----- -10 - -query I -SELECT extract(month from arrow_cast('10 years', 
'Interval(YearMonth)')) ----- -0 - -query I -SELECT extract(year from arrow_cast('10 months', 'Interval(YearMonth)')) ----- -0 - -query I -SELECT extract(month from arrow_cast('10 months', 'Interval(YearMonth)')) ----- -10 - -query I -SELECT extract(year from arrow_cast('20 months', 'Interval(YearMonth)')) ----- -1 - -query I -SELECT extract(month from arrow_cast('20 months', 'Interval(YearMonth)')) ----- -8 - -query error DataFusion error: Arrow error: Compute error: Year does not support: Interval\(DayTime\) -SELECT extract(year from arrow_cast('10 days', 'Interval(DayTime)')) - -query error DataFusion error: Arrow error: Compute error: Month does not support: Interval\(DayTime\) -SELECT extract(month from arrow_cast('10 days', 'Interval(DayTime)')) - -query I -SELECT extract(day from arrow_cast('10 days', 'Interval(DayTime)')) ----- -10 - -query I -SELECT extract(day from arrow_cast('14400 minutes', 'Interval(DayTime)')) ----- -0 - -query I -SELECT extract(minute from arrow_cast('14400 minutes', 'Interval(DayTime)')) ----- -14400 - -query I -SELECT extract(second from arrow_cast('5.1 seconds', 'Interval(DayTime)')) ----- -5 - -query I -SELECT extract(second from arrow_cast('14400 minutes', 'Interval(DayTime)')) ----- -864000 - -query I -SELECT extract(second from arrow_cast('2 months', 'Interval(MonthDayNano)')) ----- -0 - -query I -SELECT extract(second from arrow_cast('2 days', 'Interval(MonthDayNano)')) ----- -0 - -query I -SELECT extract(second from arrow_cast('2 seconds', 'Interval(MonthDayNano)')) ----- -2 - -query I -SELECT extract(seconds from arrow_cast('2 seconds', 'Interval(MonthDayNano)')) ----- -2 - -query R -SELECT extract(epoch from arrow_cast('2 seconds', 'Interval(MonthDayNano)')) ----- -2 - -query I -SELECT extract(milliseconds from arrow_cast('2 seconds', 'Interval(MonthDayNano)')) ----- -2000 - -query I -SELECT extract(second from arrow_cast('2030 milliseconds', 'Interval(MonthDayNano)')) ----- -2 - -query I -SELECT extract(second from 
arrow_cast(NULL, 'Interval(MonthDayNano)')) ----- -NULL - -statement ok -create table t (id int, i interval) as values - (0, interval '5 months 1 day 10 nanoseconds'), - (1, interval '1 year 3 months'), - (2, interval '3 days 2 milliseconds'), - (3, interval '2 seconds'), - (4, interval '8 months'), - (5, NULL); - -query III -select - id, - extract(second from i), - extract(month from i) -from t -order by id; ----- -0 0 5 -1 0 15 -2 0 0 -3 2 0 -4 0 8 -5 NULL NULL - -statement ok -drop table t; - -# test_extract_duration - -query I -SELECT extract(second from arrow_cast(2, 'Duration(Second)')) ----- -2 - -query I -SELECT extract(seconds from arrow_cast(2, 'Duration(Second)')) ----- -2 - -query R -SELECT extract(epoch from arrow_cast(2, 'Duration(Second)')) ----- -2 - -query I -SELECT extract(millisecond from arrow_cast(2, 'Duration(Second)')) ----- -2000 - -query I -SELECT extract(second from arrow_cast(2, 'Duration(Millisecond)')) ----- -0 - -query I -SELECT extract(second from arrow_cast(2002, 'Duration(Millisecond)')) ----- -2 - -query I -SELECT extract(millisecond from arrow_cast(2002, 'Duration(Millisecond)')) ----- -2002 - -query I -SELECT extract(day from arrow_cast(864000, 'Duration(Second)')) ----- -10 - -query error DataFusion error: Arrow error: Compute error: Month does not support: Duration\(Second\) -SELECT extract(month from arrow_cast(864000, 'Duration(Second)')) - -query error DataFusion error: Arrow error: Compute error: Year does not support: Duration\(Second\) -SELECT extract(year from arrow_cast(864000, 'Duration(Second)')) - -query I -SELECT extract(day from arrow_cast(NULL, 'Duration(Second)')) ----- -NULL - -# test_extract_date_part_func - -query B -SELECT (date_part('year', now()) = EXTRACT(year FROM now())) ----- -true - -query B -SELECT (date_part('quarter', now()) = EXTRACT(quarter FROM now())) ----- -true - -query B -SELECT (date_part('month', now()) = EXTRACT(month FROM now())) ----- -true - -query B -SELECT (date_part('week', now()) = 
EXTRACT(week FROM now())) ----- -true - -query B -SELECT (date_part('day', now()) = EXTRACT(day FROM now())) ----- -true - -query B -SELECT (date_part('hour', now()) = EXTRACT(hour FROM now())) ----- -true - -query B -SELECT (date_part('minute', now()) = EXTRACT(minute FROM now())) ----- -true - -query B -SELECT (date_part('second', now()) = EXTRACT(second FROM now())) ----- -true - -query B -SELECT (date_part('millisecond', now()) = EXTRACT(millisecond FROM now())) ----- -true - -query B -SELECT (date_part('microsecond', now()) = EXTRACT(microsecond FROM now())) ----- -true - -query error DataFusion error: Internal error: unit Nanosecond not supported -SELECT (date_part('nanosecond', now()) = EXTRACT(nanosecond FROM now())) - query B SELECT 'a' IN ('a','b') ---- diff --git a/datafusion/sqllogictest/test_files/expr/date_part.slt b/datafusion/sqllogictest/test_files/expr/date_part.slt new file mode 100644 index 000000000000..cec80a165f30 --- /dev/null +++ b/datafusion/sqllogictest/test_files/expr/date_part.slt @@ -0,0 +1,878 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Tests for `date_part` and `EXTRACT` (which is a different syntax +# for the same function). 
+ +query error +SELECT EXTRACT("'''year'''" FROM timestamp '2020-09-08T12:00:00+00:00') + +query error +SELECT EXTRACT("'year'" FROM timestamp '2020-09-08T12:00:00+00:00') + +query I +SELECT date_part('YEAR', CAST('2000-01-01' AS DATE)) +---- +2000 + +query I +SELECT EXTRACT(year FROM timestamp '2020-09-08T12:00:00+00:00') +---- +2020 + +query I +SELECT EXTRACT("year" FROM timestamp '2020-09-08T12:00:00+00:00') +---- +2020 + +query I +SELECT EXTRACT('year' FROM timestamp '2020-09-08T12:00:00+00:00') +---- +2020 + +query I +SELECT date_part('QUARTER', CAST('2000-01-01' AS DATE)) +---- +1 + +query I +SELECT EXTRACT(quarter FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +3 + +query I +SELECT EXTRACT("quarter" FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +3 + +query I +SELECT EXTRACT('quarter' FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +3 + +query I +SELECT date_part('MONTH', CAST('2000-01-01' AS DATE)) +---- +1 + +query I +SELECT EXTRACT(month FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +9 + +query I +SELECT EXTRACT("month" FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +9 + +query I +SELECT EXTRACT('month' FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +9 + +query I +SELECT date_part('WEEK', CAST('2003-01-01' AS DATE)) +---- +1 + +query I +SELECT EXTRACT(WEEK FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +37 + +query I +SELECT EXTRACT("WEEK" FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +37 + +query I +SELECT EXTRACT('WEEK' FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +37 + +query I +SELECT date_part('DAY', CAST('2000-01-01' AS DATE)) +---- +1 + +query I +SELECT EXTRACT(day FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +8 + +query I +SELECT EXTRACT("day" FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +8 + +query I +SELECT EXTRACT('day' FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +8 + +query I +SELECT date_part('DOY', CAST('2000-01-01' AS DATE)) +---- +1 + +query I 
+SELECT EXTRACT(doy FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +252 + +query I +SELECT EXTRACT("doy" FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +252 + +query I +SELECT EXTRACT('doy' FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +252 + +query I +SELECT date_part('DOW', CAST('2000-01-01' AS DATE)) +---- +6 + +query I +SELECT EXTRACT(dow FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +2 + +query I +SELECT EXTRACT("dow" FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +2 + +query I +SELECT EXTRACT('dow' FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +2 + +query I +SELECT date_part('HOUR', CAST('2000-01-01' AS DATE)) +---- +0 + +query I +SELECT EXTRACT(hour FROM to_timestamp('2020-09-08T12:03:03+00:00')) +---- +12 + +query I +SELECT EXTRACT("hour" FROM to_timestamp('2020-09-08T12:03:03+00:00')) +---- +12 + +query I +SELECT EXTRACT('hour' FROM to_timestamp('2020-09-08T12:03:03+00:00')) +---- +12 + +query I +SELECT EXTRACT(minute FROM to_timestamp('2020-09-08T12:12:00+00:00')) +---- +12 + +query I +SELECT EXTRACT("minute" FROM to_timestamp('2020-09-08T12:12:00+00:00')) +---- +12 + +query I +SELECT EXTRACT('minute' FROM to_timestamp('2020-09-08T12:12:00+00:00')) +---- +12 + +query I +SELECT date_part('minute', to_timestamp('2020-09-08T12:12:00+00:00')) +---- +12 + +# make sure the return type is integer +query T +SELECT arrow_typeof(date_part('minute', to_timestamp('2020-09-08T12:12:00+00:00'))) +---- +Int32 + +query I +SELECT EXTRACT(second FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12 + +query I +SELECT EXTRACT(millisecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123 + +query I +SELECT EXTRACT(microsecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123456 + +query error DataFusion error: Internal error: unit Nanosecond not supported +SELECT EXTRACT(nanosecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00') + +query I +SELECT EXTRACT("second" FROM timestamp 
'2020-09-08T12:00:12.12345678+00:00') +---- +12 + +query I +SELECT EXTRACT("millisecond" FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123 + +query I +SELECT EXTRACT("microsecond" FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123456 + +query error DataFusion error: Internal error: unit Nanosecond not supported +SELECT EXTRACT("nanosecond" FROM timestamp '2020-09-08T12:00:12.12345678+00:00') + +query I +SELECT EXTRACT('second' FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12 + +query I +SELECT EXTRACT('millisecond' FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123 + +query I +SELECT EXTRACT('microsecond' FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123456 + +query error DataFusion error: Internal error: unit Nanosecond not supported +SELECT EXTRACT('nanosecond' FROM timestamp '2020-09-08T12:00:12.12345678+00:00') + + +# Keep precision when coercing Utf8 to Timestamp +query I +SELECT date_part('second', timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12 + +query I +SELECT date_part('millisecond', timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123 + +query I +SELECT date_part('microsecond', timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123456 + +query error DataFusion error: Internal error: unit Nanosecond not supported +SELECT date_part('nanosecond', timestamp '2020-09-08T12:00:12.12345678+00:00') + + +query I +SELECT date_part('second', '2020-09-08T12:00:12.12345678+00:00') +---- +12 + +query I +SELECT date_part('millisecond', '2020-09-08T12:00:12.12345678+00:00') +---- +12123 + +query I +SELECT date_part('microsecond', '2020-09-08T12:00:12.12345678+00:00') +---- +12123456 + +query error DataFusion error: Internal error: unit Nanosecond not supported +SELECT date_part('nanosecond', '2020-09-08T12:00:12.12345678+00:00') + +# test_date_part_time + +## time32 seconds +query I +SELECT date_part('hour', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +23 + 
+query I +SELECT extract(hour from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +23 + +query I +SELECT date_part('minute', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +32 + +query I +SELECT extract(minute from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +32 + +query I +SELECT date_part('second', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50 + +query I +SELECT extract(second from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50 + +query I +SELECT date_part('millisecond', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50000 + +query I +SELECT extract(millisecond from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50000 + +query I +SELECT date_part('microsecond', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50000000 + +query I +SELECT extract(microsecond from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50000000 + +query error DataFusion error: Internal error: unit Nanosecond not supported +SELECT extract(nanosecond from arrow_cast('23:32:50'::time, 'Time32(Second)')) + +query R +SELECT date_part('epoch', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +84770 + +query R +SELECT extract(epoch from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +84770 + +## time32 milliseconds +query I +SELECT date_part('hour', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +23 + +query I +SELECT extract(hour from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +23 + +query I +SELECT date_part('minute', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +32 + +query I +SELECT extract(minute from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +32 + +query I +SELECT date_part('second', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50 + +query I +SELECT extract(second from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50 + +query I +SELECT date_part('millisecond', arrow_cast('23:32:50.123'::time, 
'Time32(Millisecond)')) +---- +50123 + +query I +SELECT extract(millisecond from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50123 + +query I +SELECT date_part('microsecond', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50123000 + +query I +SELECT extract(microsecond from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50123000 + +query error DataFusion error: Internal error: unit Nanosecond not supported +SELECT extract(nanosecond from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) + +query R +SELECT date_part('epoch', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +84770.123 + +query R +SELECT extract(epoch from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +84770.123 + +## time64 microseconds +query I +SELECT date_part('hour', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +23 + +query I +SELECT extract(hour from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +23 + +query I +SELECT date_part('minute', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +32 + +query I +SELECT extract(minute from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +32 + +query I +SELECT date_part('second', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50 + +query I +SELECT extract(second from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50 + +query I +SELECT date_part('millisecond', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50123 + +query I +SELECT extract(millisecond from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50123 + +query I +SELECT date_part('microsecond', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50123456 + +query I +SELECT extract(microsecond from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50123456 + +query error DataFusion error: Internal error: unit Nanosecond not 
supported +SELECT extract(nanosecond from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) + +query R +SELECT date_part('epoch', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +84770.123456 + +query R +SELECT extract(epoch from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +84770.123456 + +## time64 nanoseconds +query I +SELECT date_part('hour', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +23 + +query I +SELECT extract(hour from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +23 + +query I +SELECT date_part('minute', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +32 + +query I +SELECT extract(minute from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +32 + +query I +SELECT date_part('second', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50 + +query I +select extract(second from '2024-08-09T12:13:14') +---- +14 + +query I +select extract(seconds from '2024-08-09T12:13:14') +---- +14 + +query I +SELECT extract(second from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50 + +query I +SELECT date_part('millisecond', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50123 + +query I +SELECT extract(millisecond from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50123 + +# just some floating point stuff happening in the result here +query I +SELECT date_part('microsecond', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50123456 + +query I +SELECT extract(microsecond from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50123456 + +query I +SELECT extract(us from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50123456 + +query error DataFusion error: Internal error: unit Nanosecond not supported +SELECT date_part('nanosecond', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) 
+ +query R +SELECT date_part('epoch', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +84770.123456789 + +query R +SELECT extract(epoch from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +84770.123456789 + +# test_extract_epoch + +query R +SELECT extract(epoch from '1870-01-01T07:29:10.256'::timestamp) +---- +-3155646649.744 + +query R +SELECT extract(epoch from '2000-01-01T00:00:00.000'::timestamp) +---- +946684800 + +query R +SELECT extract(epoch from to_timestamp('2000-01-01T00:00:00+00:00')) +---- +946684800 + +query R +SELECT extract(epoch from NULL::timestamp) +---- +NULL + +query R +SELECT extract(epoch from arrow_cast('1970-01-01', 'Date32')) +---- +0 + +query R +SELECT extract(epoch from arrow_cast('1970-01-02', 'Date32')) +---- +86400 + +query R +SELECT extract(epoch from arrow_cast('1970-01-11', 'Date32')) +---- +864000 + +query R +SELECT extract(epoch from arrow_cast('1969-12-31', 'Date32')) +---- +-86400 + +query R +SELECT extract(epoch from arrow_cast('1970-01-01', 'Date64')) +---- +0 + +query R +SELECT extract(epoch from arrow_cast('1970-01-02', 'Date64')) +---- +86400 + +query R +SELECT extract(epoch from arrow_cast('1970-01-11', 'Date64')) +---- +864000 + +query R +SELECT extract(epoch from arrow_cast('1969-12-31', 'Date64')) +---- +-86400 + +# test_extract_interval + +query I +SELECT extract(year from arrow_cast('10 years', 'Interval(YearMonth)')) +---- +10 + +query I +SELECT extract(month from arrow_cast('10 years', 'Interval(YearMonth)')) +---- +0 + +query I +SELECT extract(year from arrow_cast('10 months', 'Interval(YearMonth)')) +---- +0 + +query I +SELECT extract(month from arrow_cast('10 months', 'Interval(YearMonth)')) +---- +10 + +query I +SELECT extract(year from arrow_cast('20 months', 'Interval(YearMonth)')) +---- +1 + +query I +SELECT extract(month from arrow_cast('20 months', 'Interval(YearMonth)')) +---- +8 + +query error DataFusion error: Arrow error: Compute error: Year does not support: 
Interval\(DayTime\) +SELECT extract(year from arrow_cast('10 days', 'Interval(DayTime)')) + +query error DataFusion error: Arrow error: Compute error: Month does not support: Interval\(DayTime\) +SELECT extract(month from arrow_cast('10 days', 'Interval(DayTime)')) + +query I +SELECT extract(day from arrow_cast('10 days', 'Interval(DayTime)')) +---- +10 + +query I +SELECT extract(day from arrow_cast('14400 minutes', 'Interval(DayTime)')) +---- +0 + +query I +SELECT extract(minute from arrow_cast('14400 minutes', 'Interval(DayTime)')) +---- +14400 + +query I +SELECT extract(second from arrow_cast('5.1 seconds', 'Interval(DayTime)')) +---- +5 + +query I +SELECT extract(second from arrow_cast('14400 minutes', 'Interval(DayTime)')) +---- +864000 + +query I +SELECT extract(second from arrow_cast('2 months', 'Interval(MonthDayNano)')) +---- +0 + +query I +SELECT extract(second from arrow_cast('2 days', 'Interval(MonthDayNano)')) +---- +0 + +query I +SELECT extract(second from arrow_cast('2 seconds', 'Interval(MonthDayNano)')) +---- +2 + +query I +SELECT extract(seconds from arrow_cast('2 seconds', 'Interval(MonthDayNano)')) +---- +2 + +query R +SELECT extract(epoch from arrow_cast('2 seconds', 'Interval(MonthDayNano)')) +---- +2 + +query I +SELECT extract(milliseconds from arrow_cast('2 seconds', 'Interval(MonthDayNano)')) +---- +2000 + +query I +SELECT extract(second from arrow_cast('2030 milliseconds', 'Interval(MonthDayNano)')) +---- +2 + +query I +SELECT extract(second from arrow_cast(NULL, 'Interval(MonthDayNano)')) +---- +NULL + +statement ok +create table t (id int, i interval) as values + (0, interval '5 months 1 day 10 nanoseconds'), + (1, interval '1 year 3 months'), + (2, interval '3 days 2 milliseconds'), + (3, interval '2 seconds'), + (4, interval '8 months'), + (5, NULL); + +query III +select + id, + extract(second from i), + extract(month from i) +from t +order by id; +---- +0 0 5 +1 0 15 +2 0 0 +3 2 0 +4 0 8 +5 NULL NULL + +statement ok +drop table t; + 
+# test_extract_duration + +query I +SELECT extract(second from arrow_cast(2, 'Duration(Second)')) +---- +2 + +query I +SELECT extract(seconds from arrow_cast(2, 'Duration(Second)')) +---- +2 + +query R +SELECT extract(epoch from arrow_cast(2, 'Duration(Second)')) +---- +2 + +query I +SELECT extract(millisecond from arrow_cast(2, 'Duration(Second)')) +---- +2000 + +query I +SELECT extract(second from arrow_cast(2, 'Duration(Millisecond)')) +---- +0 + +query I +SELECT extract(second from arrow_cast(2002, 'Duration(Millisecond)')) +---- +2 + +query I +SELECT extract(millisecond from arrow_cast(2002, 'Duration(Millisecond)')) +---- +2002 + +query I +SELECT extract(day from arrow_cast(864000, 'Duration(Second)')) +---- +10 + +query error DataFusion error: Arrow error: Compute error: Month does not support: Duration\(Second\) +SELECT extract(month from arrow_cast(864000, 'Duration(Second)')) + +query error DataFusion error: Arrow error: Compute error: Year does not support: Duration\(Second\) +SELECT extract(year from arrow_cast(864000, 'Duration(Second)')) + +query I +SELECT extract(day from arrow_cast(NULL, 'Duration(Second)')) +---- +NULL + +# test_extract_date_part_func + +query B +SELECT (date_part('year', now()) = EXTRACT(year FROM now())) +---- +true + +query B +SELECT (date_part('quarter', now()) = EXTRACT(quarter FROM now())) +---- +true + +query B +SELECT (date_part('month', now()) = EXTRACT(month FROM now())) +---- +true + +query B +SELECT (date_part('week', now()) = EXTRACT(week FROM now())) +---- +true + +query B +SELECT (date_part('day', now()) = EXTRACT(day FROM now())) +---- +true + +query B +SELECT (date_part('hour', now()) = EXTRACT(hour FROM now())) +---- +true + +query B +SELECT (date_part('minute', now()) = EXTRACT(minute FROM now())) +---- +true + +query B +SELECT (date_part('second', now()) = EXTRACT(second FROM now())) +---- +true + +query B +SELECT (date_part('millisecond', now()) = EXTRACT(millisecond FROM now())) +---- +true + +query B +SELECT 
(date_part('microsecond', now()) = EXTRACT(microsecond FROM now())) +---- +true + +query error DataFusion error: Internal error: unit Nanosecond not supported +SELECT (date_part('nanosecond', now()) = EXTRACT(nanosecond FROM now())) From 2b65fb3add471cc23701b12ff9cb236f6715af75 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Wed, 11 Dec 2024 23:27:31 +0800 Subject: [PATCH 33/35] Support unparsing `UNNEST` plan to `UNNEST` table factor SQL (#13660) * add `unnest_as_table_factor` and `UnnestRelationBuilder` * unparse unnest as table factor * fix typo * add tests for the default configs * add a static const for unnest_placeholder * fix tests * fix tests --- datafusion/sql/src/unparser/ast.rs | 73 ++++++++++++++ datafusion/sql/src/unparser/dialect.rs | 23 +++++ datafusion/sql/src/unparser/plan.rs | 55 ++++++++++- datafusion/sql/src/unparser/utils.rs | 2 +- datafusion/sql/src/utils.rs | 50 +++++----- datafusion/sql/tests/cases/plan_to_sql.rs | 99 ++++++++++++++++++- .../sqllogictest/test_files/encoding.slt | 2 +- datafusion/sqllogictest/test_files/joins.slt | 12 +-- .../test_files/push_down_filter.slt | 40 ++++---- .../test_files/table_functions.slt | 2 +- datafusion/sqllogictest/test_files/unnest.slt | 28 +++--- datafusion/sqllogictest/test_files/window.slt | 1 - 12 files changed, 313 insertions(+), 74 deletions(-) diff --git a/datafusion/sql/src/unparser/ast.rs b/datafusion/sql/src/unparser/ast.rs index cc0812cd71e1..ad0b5f16b283 100644 --- a/datafusion/sql/src/unparser/ast.rs +++ b/datafusion/sql/src/unparser/ast.rs @@ -353,6 +353,7 @@ pub(super) struct RelationBuilder { enum TableFactorBuilder { Table(TableRelationBuilder), Derived(DerivedRelationBuilder), + Unnest(UnnestRelationBuilder), Empty, } @@ -369,6 +370,12 @@ impl RelationBuilder { self.relation = Some(TableFactorBuilder::Derived(value)); self } + + pub fn unnest(&mut self, value: UnnestRelationBuilder) -> &mut Self { + self.relation = Some(TableFactorBuilder::Unnest(value)); + self + } + pub fn empty(&mut 
self) -> &mut Self { self.relation = Some(TableFactorBuilder::Empty); self @@ -382,6 +389,9 @@ impl RelationBuilder { Some(TableFactorBuilder::Derived(ref mut rel_builder)) => { rel_builder.alias = value; } + Some(TableFactorBuilder::Unnest(ref mut rel_builder)) => { + rel_builder.alias = value; + } Some(TableFactorBuilder::Empty) => (), None => (), } @@ -391,6 +401,7 @@ impl RelationBuilder { Ok(match self.relation { Some(TableFactorBuilder::Table(ref value)) => Some(value.build()?), Some(TableFactorBuilder::Derived(ref value)) => Some(value.build()?), + Some(TableFactorBuilder::Unnest(ref value)) => Some(value.build()?), Some(TableFactorBuilder::Empty) => None, None => return Err(Into::into(UninitializedFieldError::from("relation"))), }) @@ -526,6 +537,68 @@ impl Default for DerivedRelationBuilder { } } +#[derive(Clone)] +pub(super) struct UnnestRelationBuilder { + pub alias: Option, + pub array_exprs: Vec, + with_offset: bool, + with_offset_alias: Option, + with_ordinality: bool, +} + +#[allow(dead_code)] +impl UnnestRelationBuilder { + pub fn alias(&mut self, value: Option) -> &mut Self { + self.alias = value; + self + } + pub fn array_exprs(&mut self, value: Vec) -> &mut Self { + self.array_exprs = value; + self + } + + pub fn with_offset(&mut self, value: bool) -> &mut Self { + self.with_offset = value; + self + } + + pub fn with_offset_alias(&mut self, value: Option) -> &mut Self { + self.with_offset_alias = value; + self + } + + pub fn with_ordinality(&mut self, value: bool) -> &mut Self { + self.with_ordinality = value; + self + } + + pub fn build(&self) -> Result { + Ok(ast::TableFactor::UNNEST { + alias: self.alias.clone(), + array_exprs: self.array_exprs.clone(), + with_offset: self.with_offset, + with_offset_alias: self.with_offset_alias.clone(), + with_ordinality: self.with_ordinality, + }) + } + + fn create_empty() -> Self { + Self { + alias: Default::default(), + array_exprs: Default::default(), + with_offset: Default::default(), + 
with_offset_alias: Default::default(), + with_ordinality: Default::default(), + } + } +} + +impl Default for UnnestRelationBuilder { + fn default() -> Self { + Self::create_empty() + } +} + /// Runtime error when a `build()` method is called and one or more required fields /// do not have a value. #[derive(Debug, Clone)] diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index e979d8fd4ebd..ae387d441fa2 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -157,6 +157,15 @@ pub trait Dialect: Send + Sync { fn full_qualified_col(&self) -> bool { false } + + /// Allow to unparse the unnest plan as [ast::TableFactor::UNNEST]. + /// + /// Some dialects like BigQuery require UNNEST to be used in the FROM clause but + /// the LogicalPlan planner always puts UNNEST in the SELECT clause. This flag allows + /// to unparse the UNNEST plan as [ast::TableFactor::UNNEST] instead of a subquery. + fn unnest_as_table_factor(&self) -> bool { + false + } } /// `IntervalStyle` to use for unparsing @@ -448,6 +457,7 @@ pub struct CustomDialect { requires_derived_table_alias: bool, division_operator: BinaryOperator, full_qualified_col: bool, + unnest_as_table_factor: bool, } impl Default for CustomDialect { @@ -474,6 +484,7 @@ impl Default for CustomDialect { requires_derived_table_alias: false, division_operator: BinaryOperator::Divide, full_qualified_col: false, + unnest_as_table_factor: false, } } } @@ -582,6 +593,10 @@ impl Dialect for CustomDialect { fn full_qualified_col(&self) -> bool { self.full_qualified_col } + + fn unnest_as_table_factor(&self) -> bool { + self.unnest_as_table_factor + } } /// `CustomDialectBuilder` to build `CustomDialect` using builder pattern @@ -617,6 +632,7 @@ pub struct CustomDialectBuilder { requires_derived_table_alias: bool, division_operator: BinaryOperator, full_qualified_col: bool, + unnest_as_table_factor: bool, } impl Default for CustomDialectBuilder { @@ 
-649,6 +665,7 @@ impl CustomDialectBuilder { requires_derived_table_alias: false, division_operator: BinaryOperator::Divide, full_qualified_col: false, + unnest_as_table_factor: false, } } @@ -673,6 +690,7 @@ impl CustomDialectBuilder { requires_derived_table_alias: self.requires_derived_table_alias, division_operator: self.division_operator, full_qualified_col: self.full_qualified_col, + unnest_as_table_factor: self.unnest_as_table_factor, } } @@ -800,4 +818,9 @@ impl CustomDialectBuilder { self.full_qualified_col = full_qualified_col; self } + + pub fn with_unnest_as_table_factor(mut self, _unnest_as_table_factor: bool) -> Self { + self.unnest_as_table_factor = _unnest_as_table_factor; + self + } } diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index eaae4fe73d8c..e9f9f486ea9a 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -32,7 +32,9 @@ use super::{ }, Unparser, }; +use crate::unparser::ast::UnnestRelationBuilder; use crate::unparser::utils::unproject_agg_exprs; +use crate::utils::UNNEST_PLACEHOLDER; use datafusion_common::{ internal_err, not_impl_err, tree_node::{TransformedResult, TreeNode}, @@ -40,7 +42,7 @@ use datafusion_common::{ }; use datafusion_expr::{ expr::Alias, BinaryExpr, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan, - LogicalPlanBuilder, Operator, Projection, SortExpr, TableScan, + LogicalPlanBuilder, Operator, Projection, SortExpr, TableScan, Unnest, }; use sqlparser::ast::{self, Ident, SetExpr}; use std::sync::Arc; @@ -312,6 +314,19 @@ impl Unparser<'_> { .select_to_sql_recursively(&new_plan, query, select, relation); } + // Projection can be top-level plan for unnest relation + // The projection generated by the `RecursiveUnnestRewriter` from a UNNEST relation will have + // only one expression, which is the placeholder column generated by the rewriter. 
+ if self.dialect.unnest_as_table_factor() + && p.expr.len() == 1 + && Self::is_unnest_placeholder(&p.expr[0]) + { + if let LogicalPlan::Unnest(unnest) = &p.input.as_ref() { + return self + .unnest_to_table_factor_sql(unnest, query, select, relation); + } + } + // Projection can be top-level plan for derived table if select.already_projected() { return self.derive_with_dialect_alias( @@ -678,7 +693,11 @@ impl Unparser<'_> { ) } LogicalPlan::EmptyRelation(_) => { - relation.empty(); + // An EmptyRelation could be behind an UNNEST node. If the dialect supports UNNEST as a table factor, + // a TableRelationBuilder will be created for the UNNEST node first. + if !relation.has_relation() { + relation.empty(); + } Ok(()) } LogicalPlan::Extension(_) => not_impl_err!("Unsupported operator: {plan:?}"), @@ -708,6 +727,38 @@ impl Unparser<'_> { } } + /// Try to find the placeholder column name generated by `RecursiveUnnestRewriter` + /// Only match the pattern `Expr::Alias(Expr::Column("__unnest_placeholder(...)"))` + fn is_unnest_placeholder(expr: &Expr) -> bool { + if let Expr::Alias(Alias { expr, .. }) = expr { + if let Expr::Column(Column { name, .. 
}) = expr.as_ref() { + return name.starts_with(UNNEST_PLACEHOLDER); + } + } + false + } + + fn unnest_to_table_factor_sql( + &self, + unnest: &Unnest, + query: &mut Option, + select: &mut SelectBuilder, + relation: &mut RelationBuilder, + ) -> Result<()> { + let mut unnest_relation = UnnestRelationBuilder::default(); + let LogicalPlan::Projection(p) = unnest.input.as_ref() else { + return internal_err!("Unnest input is not a Projection: {unnest:?}"); + }; + let exprs = p + .expr + .iter() + .map(|e| self.expr_to_sql(e)) + .collect::>>()?; + unnest_relation.array_exprs(exprs); + relation.unnest(unnest_relation); + self.select_to_sql_recursively(p.input.as_ref(), query, select, relation) + } + fn is_scan_with_pushdown(scan: &TableScan) -> bool { scan.projection.is_some() || !scan.filters.is_empty() || scan.fetch.is_some() } diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index 518781106c3b..354a68f60964 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -133,7 +133,7 @@ pub(crate) fn find_window_nodes_within_select<'a>( /// Recursively identify Column expressions and transform them into the appropriate unnest expression /// -/// For example, if expr contains the column expr "unnest_placeholder(make_array(Int64(1),Int64(2),Int64(2),Int64(5),NULL),depth=1)" +/// For example, if expr contains the column expr "__unnest_placeholder(make_array(Int64(1),Int64(2),Int64(2),Int64(5),NULL),depth=1)" /// it will be transformed into an actual unnest expression UNNEST([1, 2, 2, 5, NULL]) pub(crate) fn unproject_unnest_expr(expr: Expr, unnest: &Unnest) -> Result { expr.transform(|sub_expr| { diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 69e3953341ef..1c2a3ea91a2b 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -315,6 +315,8 @@ pub(crate) fn rewrite_recursive_unnests_bottom_up( .collect::>()) } +pub const UNNEST_PLACEHOLDER: &str = 
"__unnest_placeholder"; + /* This is only usedful when used with transform down up A full example of how the transformation works: @@ -360,9 +362,9 @@ impl RecursiveUnnestRewriter<'_> { // Full context, we are trying to plan the execution as InnerProjection->Unnest->OuterProjection // inside unnest execution, each column inside the inner projection // will be transformed into new columns. Thus we need to keep track of these placeholding column names - let placeholder_name = format!("unnest_placeholder({})", inner_expr_name); + let placeholder_name = format!("{UNNEST_PLACEHOLDER}({})", inner_expr_name); let post_unnest_name = - format!("unnest_placeholder({},depth={})", inner_expr_name, level); + format!("{UNNEST_PLACEHOLDER}({},depth={})", inner_expr_name, level); // This is due to the fact that unnest transformation should keep the original // column name as is, to comply with group by and order by let placeholder_column = Column::from_name(placeholder_name.clone()); @@ -693,17 +695,17 @@ mod tests { // Only the bottom most unnest exprs are transformed assert_eq!( transformed_exprs, - vec![col("unnest_placeholder(3d_col,depth=2)") + vec![col("__unnest_placeholder(3d_col,depth=2)") .alias("UNNEST(UNNEST(3d_col))") .add( - col("unnest_placeholder(3d_col,depth=2)") + col("__unnest_placeholder(3d_col,depth=2)") .alias("UNNEST(UNNEST(3d_col))") ) .add(col("i64_col"))] ); column_unnests_eq( vec![ - "unnest_placeholder(3d_col)=>[unnest_placeholder(3d_col,depth=2)|depth=2]", + "__unnest_placeholder(3d_col)=>[__unnest_placeholder(3d_col,depth=2)|depth=2]", ], &unnest_placeholder_columns, ); @@ -713,7 +715,7 @@ mod tests { assert_eq!( inner_projection_exprs, vec![ - col("3d_col").alias("unnest_placeholder(3d_col)"), + col("3d_col").alias("__unnest_placeholder(3d_col)"), col("i64_col") ] ); @@ -730,12 +732,12 @@ mod tests { assert_eq!( transformed_exprs, vec![ - (col("unnest_placeholder(3d_col,depth=1)").alias("UNNEST(3d_col)")) + 
(col("__unnest_placeholder(3d_col,depth=1)").alias("UNNEST(3d_col)")) .alias("2d_col") ] ); column_unnests_eq( - vec!["unnest_placeholder(3d_col)=>[unnest_placeholder(3d_col,depth=2)|depth=2, unnest_placeholder(3d_col,depth=1)|depth=1]"], + vec!["__unnest_placeholder(3d_col)=>[__unnest_placeholder(3d_col,depth=2)|depth=2, __unnest_placeholder(3d_col,depth=1)|depth=1]"], &unnest_placeholder_columns, ); // Still reference struct_col in original schema but with alias, @@ -743,7 +745,7 @@ mod tests { assert_eq!( inner_projection_exprs, vec![ - col("3d_col").alias("unnest_placeholder(3d_col)"), + col("3d_col").alias("__unnest_placeholder(3d_col)"), col("i64_col") ] ); @@ -794,19 +796,19 @@ mod tests { assert_eq!( transformed_exprs, vec![ - col("unnest_placeholder(struct_col).field1"), - col("unnest_placeholder(struct_col).field2"), + col("__unnest_placeholder(struct_col).field1"), + col("__unnest_placeholder(struct_col).field2"), ] ); column_unnests_eq( - vec!["unnest_placeholder(struct_col)"], + vec!["__unnest_placeholder(struct_col)"], &unnest_placeholder_columns, ); // Still reference struct_col in original schema but with alias, // to avoid colliding with the projection on the column itself if any assert_eq!( inner_projection_exprs, - vec![col("struct_col").alias("unnest_placeholder(struct_col)"),] + vec![col("struct_col").alias("__unnest_placeholder(struct_col)"),] ); // unnest(array_col) + 1 @@ -819,15 +821,15 @@ mod tests { )?; column_unnests_eq( vec![ - "unnest_placeholder(struct_col)", - "unnest_placeholder(array_col)=>[unnest_placeholder(array_col,depth=1)|depth=1]", + "__unnest_placeholder(struct_col)", + "__unnest_placeholder(array_col)=>[__unnest_placeholder(array_col,depth=1)|depth=1]", ], &unnest_placeholder_columns, ); // Only transform the unnest children assert_eq!( transformed_exprs, - vec![col("unnest_placeholder(array_col,depth=1)") + vec![col("__unnest_placeholder(array_col,depth=1)") .alias("UNNEST(array_col)") .add(lit(1i64))] ); @@ -838,8 +840,8 
@@ mod tests { assert_eq!( inner_projection_exprs, vec![ - col("struct_col").alias("unnest_placeholder(struct_col)"), - col("array_col").alias("unnest_placeholder(array_col)") + col("struct_col").alias("__unnest_placeholder(struct_col)"), + col("array_col").alias("__unnest_placeholder(array_col)") ] ); @@ -907,7 +909,7 @@ mod tests { assert_eq!( transformed_exprs, vec![unnest( - col("unnest_placeholder(struct_list,depth=1)") + col("__unnest_placeholder(struct_list,depth=1)") .alias("UNNEST(struct_list)") .field("subfield1") )] @@ -915,14 +917,14 @@ mod tests { column_unnests_eq( vec![ - "unnest_placeholder(struct_list)=>[unnest_placeholder(struct_list,depth=1)|depth=1]", + "__unnest_placeholder(struct_list)=>[__unnest_placeholder(struct_list,depth=1)|depth=1]", ], &unnest_placeholder_columns, ); assert_eq!( inner_projection_exprs, - vec![col("struct_list").alias("unnest_placeholder(struct_list)")] + vec![col("struct_list").alias("__unnest_placeholder(struct_list)")] ); // continue rewrite another expr in select @@ -937,7 +939,7 @@ mod tests { assert_eq!( transformed_exprs, vec![unnest( - col("unnest_placeholder(struct_list,depth=1)") + col("__unnest_placeholder(struct_list,depth=1)") .alias("UNNEST(struct_list)") .field("subfield2") )] @@ -947,14 +949,14 @@ mod tests { // because expr1 and expr2 derive from the same unnest result column_unnests_eq( vec![ - "unnest_placeholder(struct_list)=>[unnest_placeholder(struct_list,depth=1)|depth=1]", + "__unnest_placeholder(struct_list)=>[__unnest_placeholder(struct_list,depth=1)|depth=1]", ], &unnest_placeholder_columns, ); assert_eq!( inner_projection_exprs, - vec![col("struct_list").alias("unnest_placeholder(struct_list)")] + vec![col("struct_list").alias("__unnest_placeholder(struct_list)")] ); Ok(()) diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index fcfee29f6ac9..236b59432a5f 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ 
b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -525,6 +525,96 @@ fn roundtrip_statement_with_dialect() -> Result<()> { parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(SqliteDialect {}), }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3])", + expected: r#"SELECT * FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))")"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3]) AS t1 (c1)", + expected: r#"SELECT * FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS t1 (c1)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3]) AS t1 (c1)", + expected: r#"SELECT * FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS t1 (c1)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3]), j1", + expected: r#"SELECT * FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") CROSS JOIN j1"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3]) u(c1) JOIN j1 ON u.c1 = j1.j1_id", + expected: r#"SELECT * FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS u (c1) JOIN j1 ON (u.c1 = j1.j1_id)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3]) u(c1) UNION ALL SELECT * FROM UNNEST([4,5,6]) u(c1)", + expected: r#"SELECT * FROM (SELECT UNNEST([1, 2, 3]) AS 
"UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS u (c1) UNION ALL SELECT * FROM (SELECT UNNEST([4, 5, 6]) AS "UNNEST(make_array(Int64(4),Int64(5),Int64(6)))") AS u (c1)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3])", + expected: r#"SELECT * FROM UNNEST([1, 2, 3])"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3]) AS t1 (c1)", + expected: r#"SELECT * FROM UNNEST([1, 2, 3]) AS t1 (c1)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3]) AS t1 (c1)", + expected: r#"SELECT * FROM UNNEST([1, 2, 3]) AS t1 (c1)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3]), j1", + expected: r#"SELECT * FROM UNNEST([1, 2, 3]) CROSS JOIN j1"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3]) u(c1) JOIN j1 ON u.c1 = j1.j1_id", + expected: r#"SELECT * FROM UNNEST([1, 2, 3]) AS u (c1) JOIN j1 ON (u.c1 = j1.j1_id)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3]) u(c1) UNION ALL SELECT * FROM UNNEST([4,5,6]) u(c1)", + expected: r#"SELECT * FROM UNNEST([1, 2, 3]) AS u (c1) UNION ALL SELECT * 
FROM UNNEST([4, 5, 6]) AS u (c1)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + }, + TestStatementWithDialect { + sql: "SELECT UNNEST([1,2,3])", + expected: r#"SELECT * FROM UNNEST([1, 2, 3])"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + }, + TestStatementWithDialect { + sql: "SELECT UNNEST([1,2,3]) as c1", + expected: r#"SELECT UNNEST([1, 2, 3]) AS c1"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + }, + TestStatementWithDialect { + sql: "SELECT UNNEST([1,2,3]), 1", + expected: r#"SELECT UNNEST([1, 2, 3]) AS UNNEST(make_array(Int64(1),Int64(2),Int64(3))), Int64(1)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + }, ]; for query in tests { @@ -535,7 +625,8 @@ fn roundtrip_statement_with_dialect() -> Result<()> { let state = MockSessionState::default() .with_aggregate_function(max_udaf()) .with_aggregate_function(min_udaf()) - .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); + .with_expr_planner(Arc::new(CoreFunctionPlanner::default())) + .with_expr_planner(Arc::new(NestedFunctionPlanner)); let context = MockContextProvider { state }; let sql_to_rel = SqlToRel::new(&context); @@ -571,9 +662,9 @@ fn test_unnest_logical_plan() -> Result<()> { let sql_to_rel = SqlToRel::new(&context); let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); let expected = r#" -Projection: unnest_placeholder(unnest_table.struct_col).field1, unnest_placeholder(unnest_table.struct_col).field2, unnest_placeholder(unnest_table.array_col,depth=1) AS UNNEST(unnest_table.array_col), unnest_table.struct_col, unnest_table.array_col - Unnest: 
lists[unnest_placeholder(unnest_table.array_col)|depth=1] structs[unnest_placeholder(unnest_table.struct_col)] - Projection: unnest_table.struct_col AS unnest_placeholder(unnest_table.struct_col), unnest_table.array_col AS unnest_placeholder(unnest_table.array_col), unnest_table.struct_col, unnest_table.array_col +Projection: __unnest_placeholder(unnest_table.struct_col).field1, __unnest_placeholder(unnest_table.struct_col).field2, __unnest_placeholder(unnest_table.array_col,depth=1) AS UNNEST(unnest_table.array_col), unnest_table.struct_col, unnest_table.array_col + Unnest: lists[__unnest_placeholder(unnest_table.array_col)|depth=1] structs[__unnest_placeholder(unnest_table.struct_col)] + Projection: unnest_table.struct_col AS __unnest_placeholder(unnest_table.struct_col), unnest_table.array_col AS __unnest_placeholder(unnest_table.array_col), unnest_table.struct_col, unnest_table.array_col TableScan: unnest_table"#.trim_start(); assert_eq!(plan.to_string(), expected); diff --git a/datafusion/sqllogictest/test_files/encoding.slt b/datafusion/sqllogictest/test_files/encoding.slt index fc22cc8bf7a7..24efb33f7896 100644 --- a/datafusion/sqllogictest/test_files/encoding.slt +++ b/datafusion/sqllogictest/test_files/encoding.slt @@ -101,4 +101,4 @@ FROM test_utf8view; Andrew QW5kcmV3 416e64726577 X WA 58 Xiangpeng WGlhbmdwZW5n 5869616e6770656e67 Xiangpeng WGlhbmdwZW5n 5869616e6770656e67 Raphael UmFwaGFlbA 5261706861656c R Ug 52 -NULL NULL NULL R Ug 52 \ No newline at end of file +NULL NULL NULL R Ug 52 diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index 62f625119897..49aaa877caa6 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -4058,9 +4058,9 @@ logical_plan 03)----TableScan: join_t1 projection=[t1_id, t1_name] 04)--SubqueryAlias: series 05)----Subquery: -06)------Projection: unnest_placeholder(generate_series(Int64(1),outer_ref(t1.t1_int)),depth=1) 
AS i -07)--------Unnest: lists[unnest_placeholder(generate_series(Int64(1),outer_ref(t1.t1_int)))|depth=1] structs[] -08)----------Projection: generate_series(Int64(1), CAST(outer_ref(t1.t1_int) AS Int64)) AS unnest_placeholder(generate_series(Int64(1),outer_ref(t1.t1_int))) +06)------Projection: __unnest_placeholder(generate_series(Int64(1),outer_ref(t1.t1_int)),depth=1) AS i +07)--------Unnest: lists[__unnest_placeholder(generate_series(Int64(1),outer_ref(t1.t1_int)))|depth=1] structs[] +08)----------Projection: generate_series(Int64(1), CAST(outer_ref(t1.t1_int) AS Int64)) AS __unnest_placeholder(generate_series(Int64(1),outer_ref(t1.t1_int))) 09)------------EmptyRelation physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(UInt32, Column { relation: Some(Bare { table: "t1" }), name: "t1_int" }) @@ -4081,9 +4081,9 @@ logical_plan 03)----TableScan: join_t1 projection=[t1_id, t1_name] 04)--SubqueryAlias: series 05)----Subquery: -06)------Projection: unnest_placeholder(generate_series(Int64(1),outer_ref(t2.t1_int)),depth=1) AS i -07)--------Unnest: lists[unnest_placeholder(generate_series(Int64(1),outer_ref(t2.t1_int)))|depth=1] structs[] -08)----------Projection: generate_series(Int64(1), CAST(outer_ref(t2.t1_int) AS Int64)) AS unnest_placeholder(generate_series(Int64(1),outer_ref(t2.t1_int))) +06)------Projection: __unnest_placeholder(generate_series(Int64(1),outer_ref(t2.t1_int)),depth=1) AS i +07)--------Unnest: lists[__unnest_placeholder(generate_series(Int64(1),outer_ref(t2.t1_int)))|depth=1] structs[] +08)----------Projection: generate_series(Int64(1), CAST(outer_ref(t2.t1_int) AS Int64)) AS __unnest_placeholder(generate_series(Int64(1),outer_ref(t2.t1_int))) 09)------------EmptyRelation physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(UInt32, Column { relation: Some(Bare { table: "t2" }), name: "t1_int" }) diff 
--git a/datafusion/sqllogictest/test_files/push_down_filter.slt b/datafusion/sqllogictest/test_files/push_down_filter.slt index 86aa07b04ce1..64cc51b3c4ff 100644 --- a/datafusion/sqllogictest/test_files/push_down_filter.slt +++ b/datafusion/sqllogictest/test_files/push_down_filter.slt @@ -36,9 +36,9 @@ query TT explain select uc2 from (select unnest(column2) as uc2, column1 from v) where column1 = 2; ---- logical_plan -01)Projection: unnest_placeholder(v.column2,depth=1) AS uc2 -02)--Unnest: lists[unnest_placeholder(v.column2)|depth=1] structs[] -03)----Projection: v.column2 AS unnest_placeholder(v.column2), v.column1 +01)Projection: __unnest_placeholder(v.column2,depth=1) AS uc2 +02)--Unnest: lists[__unnest_placeholder(v.column2)|depth=1] structs[] +03)----Projection: v.column2 AS __unnest_placeholder(v.column2), v.column1 04)------Filter: v.column1 = Int64(2) 05)--------TableScan: v projection=[column1, column2] @@ -53,11 +53,11 @@ query TT explain select uc2 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3; ---- logical_plan -01)Projection: unnest_placeholder(v.column2,depth=1) AS uc2 -02)--Filter: unnest_placeholder(v.column2,depth=1) > Int64(3) -03)----Projection: unnest_placeholder(v.column2,depth=1) -04)------Unnest: lists[unnest_placeholder(v.column2)|depth=1] structs[] -05)--------Projection: v.column2 AS unnest_placeholder(v.column2), v.column1 +01)Projection: __unnest_placeholder(v.column2,depth=1) AS uc2 +02)--Filter: __unnest_placeholder(v.column2,depth=1) > Int64(3) +03)----Projection: __unnest_placeholder(v.column2,depth=1) +04)------Unnest: lists[__unnest_placeholder(v.column2)|depth=1] structs[] +05)--------Projection: v.column2 AS __unnest_placeholder(v.column2), v.column1 06)----------TableScan: v projection=[column1, column2] query II @@ -71,10 +71,10 @@ query TT explain select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3 AND column1 = 2; ---- logical_plan -01)Projection: 
unnest_placeholder(v.column2,depth=1) AS uc2, v.column1 -02)--Filter: unnest_placeholder(v.column2,depth=1) > Int64(3) -03)----Unnest: lists[unnest_placeholder(v.column2)|depth=1] structs[] -04)------Projection: v.column2 AS unnest_placeholder(v.column2), v.column1 +01)Projection: __unnest_placeholder(v.column2,depth=1) AS uc2, v.column1 +02)--Filter: __unnest_placeholder(v.column2,depth=1) > Int64(3) +03)----Unnest: lists[__unnest_placeholder(v.column2)|depth=1] structs[] +04)------Projection: v.column2 AS __unnest_placeholder(v.column2), v.column1 05)--------Filter: v.column1 = Int64(2) 06)----------TableScan: v projection=[column1, column2] @@ -90,10 +90,10 @@ query TT explain select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3 OR column1 = 2; ---- logical_plan -01)Projection: unnest_placeholder(v.column2,depth=1) AS uc2, v.column1 -02)--Filter: unnest_placeholder(v.column2,depth=1) > Int64(3) OR v.column1 = Int64(2) -03)----Unnest: lists[unnest_placeholder(v.column2)|depth=1] structs[] -04)------Projection: v.column2 AS unnest_placeholder(v.column2), v.column1 +01)Projection: __unnest_placeholder(v.column2,depth=1) AS uc2, v.column1 +02)--Filter: __unnest_placeholder(v.column2,depth=1) > Int64(3) OR v.column1 = Int64(2) +03)----Unnest: lists[__unnest_placeholder(v.column2)|depth=1] structs[] +04)------Projection: v.column2 AS __unnest_placeholder(v.column2), v.column1 05)--------TableScan: v projection=[column1, column2] statement ok @@ -112,10 +112,10 @@ query TT explain select * from (select column1, unnest(column2) as o from d) where o['a'] = 1; ---- logical_plan -01)Projection: d.column1, unnest_placeholder(d.column2,depth=1) AS o -02)--Filter: get_field(unnest_placeholder(d.column2,depth=1), Utf8("a")) = Int64(1) -03)----Unnest: lists[unnest_placeholder(d.column2)|depth=1] structs[] -04)------Projection: d.column1, d.column2 AS unnest_placeholder(d.column2) +01)Projection: d.column1, 
__unnest_placeholder(d.column2,depth=1) AS o +02)--Filter: get_field(__unnest_placeholder(d.column2,depth=1), Utf8("a")) = Int64(1) +03)----Unnest: lists[__unnest_placeholder(d.column2)|depth=1] structs[] +04)------Projection: d.column1, d.column2 AS __unnest_placeholder(d.column2) 05)--------TableScan: d projection=[column1, column2] diff --git a/datafusion/sqllogictest/test_files/table_functions.slt b/datafusion/sqllogictest/test_files/table_functions.slt index 12402e0d70c5..79294993dded 100644 --- a/datafusion/sqllogictest/test_files/table_functions.slt +++ b/datafusion/sqllogictest/test_files/table_functions.slt @@ -139,4 +139,4 @@ SELECT generate_series(1, t1.end) FROM generate_series(3, 5) as t1(end) ---- [1, 2, 3, 4, 5] [1, 2, 3, 4] -[1, 2, 3] \ No newline at end of file +[1, 2, 3] diff --git a/datafusion/sqllogictest/test_files/unnest.slt b/datafusion/sqllogictest/test_files/unnest.slt index d409e0902f7e..1c54006bd2a0 100644 --- a/datafusion/sqllogictest/test_files/unnest.slt +++ b/datafusion/sqllogictest/test_files/unnest.slt @@ -594,17 +594,17 @@ query TT explain select unnest(unnest(column3)), column3 from recursive_unnest_table; ---- logical_plan -01)Unnest: lists[] structs[unnest_placeholder(UNNEST(recursive_unnest_table.column3))] -02)--Projection: unnest_placeholder(recursive_unnest_table.column3,depth=1) AS UNNEST(recursive_unnest_table.column3) AS unnest_placeholder(UNNEST(recursive_unnest_table.column3)), recursive_unnest_table.column3 -03)----Unnest: lists[unnest_placeholder(recursive_unnest_table.column3)|depth=1] structs[] -04)------Projection: recursive_unnest_table.column3 AS unnest_placeholder(recursive_unnest_table.column3), recursive_unnest_table.column3 +01)Unnest: lists[] structs[__unnest_placeholder(UNNEST(recursive_unnest_table.column3))] +02)--Projection: __unnest_placeholder(recursive_unnest_table.column3,depth=1) AS UNNEST(recursive_unnest_table.column3) AS __unnest_placeholder(UNNEST(recursive_unnest_table.column3)), 
recursive_unnest_table.column3 +03)----Unnest: lists[__unnest_placeholder(recursive_unnest_table.column3)|depth=1] structs[] +04)------Projection: recursive_unnest_table.column3 AS __unnest_placeholder(recursive_unnest_table.column3), recursive_unnest_table.column3 05)--------TableScan: recursive_unnest_table projection=[column3] physical_plan 01)UnnestExec 02)--RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -03)----ProjectionExec: expr=[unnest_placeholder(recursive_unnest_table.column3,depth=1)@0 as unnest_placeholder(UNNEST(recursive_unnest_table.column3)), column3@1 as column3] +03)----ProjectionExec: expr=[__unnest_placeholder(recursive_unnest_table.column3,depth=1)@0 as __unnest_placeholder(UNNEST(recursive_unnest_table.column3)), column3@1 as column3] 04)------UnnestExec -05)--------ProjectionExec: expr=[column3@0 as unnest_placeholder(recursive_unnest_table.column3), column3@0 as column3] +05)--------ProjectionExec: expr=[column3@0 as __unnest_placeholder(recursive_unnest_table.column3), column3@0 as column3] 06)----------MemoryExec: partitions=1, partition_sizes=[1] ## unnest->field_access->unnest->unnest @@ -650,19 +650,19 @@ query TT explain select unnest(unnest(unnest(column3)['c1'])), column3 from recursive_unnest_table; ---- logical_plan -01)Projection: unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1],depth=2) AS UNNEST(UNNEST(UNNEST(recursive_unnest_table.column3)[c1])), recursive_unnest_table.column3 -02)--Unnest: lists[unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1])|depth=2] structs[] -03)----Projection: get_field(unnest_placeholder(recursive_unnest_table.column3,depth=1) AS UNNEST(recursive_unnest_table.column3), Utf8("c1")) AS unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1]), recursive_unnest_table.column3 -04)------Unnest: lists[unnest_placeholder(recursive_unnest_table.column3)|depth=1] structs[] -05)--------Projection: recursive_unnest_table.column3 AS 
unnest_placeholder(recursive_unnest_table.column3), recursive_unnest_table.column3 +01)Projection: __unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1],depth=2) AS UNNEST(UNNEST(UNNEST(recursive_unnest_table.column3)[c1])), recursive_unnest_table.column3 +02)--Unnest: lists[__unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1])|depth=2] structs[] +03)----Projection: get_field(__unnest_placeholder(recursive_unnest_table.column3,depth=1) AS UNNEST(recursive_unnest_table.column3), Utf8("c1")) AS __unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1]), recursive_unnest_table.column3 +04)------Unnest: lists[__unnest_placeholder(recursive_unnest_table.column3)|depth=1] structs[] +05)--------Projection: recursive_unnest_table.column3 AS __unnest_placeholder(recursive_unnest_table.column3), recursive_unnest_table.column3 06)----------TableScan: recursive_unnest_table projection=[column3] physical_plan -01)ProjectionExec: expr=[unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1],depth=2)@0 as UNNEST(UNNEST(UNNEST(recursive_unnest_table.column3)[c1])), column3@1 as column3] +01)ProjectionExec: expr=[__unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1],depth=2)@0 as UNNEST(UNNEST(UNNEST(recursive_unnest_table.column3)[c1])), column3@1 as column3] 02)--UnnestExec -03)----ProjectionExec: expr=[get_field(unnest_placeholder(recursive_unnest_table.column3,depth=1)@0, c1) as unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1]), column3@1 as column3] +03)----ProjectionExec: expr=[get_field(__unnest_placeholder(recursive_unnest_table.column3,depth=1)@0, c1) as __unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1]), column3@1 as column3] 04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 05)--------UnnestExec -06)----------ProjectionExec: expr=[column3@0 as unnest_placeholder(recursive_unnest_table.column3), column3@0 as column3] +06)----------ProjectionExec: expr=[column3@0 as 
__unnest_placeholder(recursive_unnest_table.column3), column3@0 as column3] 07)------------MemoryExec: partitions=1, partition_sizes=[1] diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 6c48ac68ab6b..188e2ae0915f 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -5127,4 +5127,3 @@ order by id; statement ok drop table t1; - From 79cb7d69e6fe8303cfd24fedd53ee51c424fc013 Mon Sep 17 00:00:00 2001 From: zhuliquan Date: Sun, 15 Dec 2024 23:10:31 +0800 Subject: [PATCH 34/35] feat: new way to make bool_buffer in scalar_regex_match --- .../benches/scalar_regex_match.rs | 8 +- .../src/expressions/scalar_regex_match.rs | 78 ++++++++++--------- 2 files changed, 48 insertions(+), 38 deletions(-) diff --git a/datafusion/physical-expr/benches/scalar_regex_match.rs b/datafusion/physical-expr/benches/scalar_regex_match.rs index 9c6826800600..99dbe100507d 100644 --- a/datafusion/physical-expr/benches/scalar_regex_match.rs +++ b/datafusion/physical-expr/benches/scalar_regex_match.rs @@ -16,6 +16,7 @@ // under the License. 
use std::sync::Arc; +use std::time::Duration; use arrow_array::{RecordBatch, StringArray}; use arrow_schema::{DataType, Field, Schema}; @@ -70,7 +71,7 @@ fn init_benchmark() -> ( let column = "s"; let schema = Schema::new(vec![Field::new(column, DataType::Utf8, true)]); - // meke test record batch + // make test record batch let batch_data = vec![ // (20, 10_usize, make_record_batch(20, 10, 100, schema.clone())), // (20, 100_usize, make_record_batch(20, 100, 100, schema.clone())), @@ -131,6 +132,11 @@ fn regex_match_benchmark(c: &mut Criterion) { name, batch_iter, batch_size ); let mut group = c.benchmark_group(group_name.as_str()); + + group + .sample_size(50) + .measurement_time(Duration::new(30,0)); + // binary expr match benchmarks group.bench_function("binary_expr_match", |b| { b.iter(|| { diff --git a/datafusion/physical-expr/src/expressions/scalar_regex_match.rs b/datafusion/physical-expr/src/expressions/scalar_regex_match.rs index b4e1e92306cb..1f7be76a95b2 100644 --- a/datafusion/physical-expr/src/expressions/scalar_regex_match.rs +++ b/datafusion/physical-expr/src/expressions/scalar_regex_match.rs @@ -19,7 +19,7 @@ use super::Literal; use arrow_array::{ Array, ArrayAccessor, BooleanArray, LargeStringArray, StringArray, StringViewArray, }; -use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder}; +use arrow_buffer::{bit_util, BooleanBuffer, MutableBuffer}; use arrow_schema::{DataType, Schema}; use datafusion_common::{Result as DFResult, ScalarValue}; use datafusion_expr::ColumnarValue; @@ -298,13 +298,17 @@ fn array_regexp_match( let bool_buffer = if regex.as_str().is_empty() { BooleanBuffer::new_set(array.len()) } else { - let mut bool_buffer_builder = BooleanBufferBuilder::new(array.len()); - bool_buffer_builder.advance(array.len()); + let mut mutable_buffer = MutableBuffer::new(0); for i in 0..array.len() { let value = unsafe { array.value_unchecked(i) }; - bool_buffer_builder.set_bit(i, regex.is_match(value)); + if i % 8 == 0 { + 
mutable_buffer.push(0u8); + } + if regex.is_match(value) { + unsafe { bit_util::set_bit_raw(mutable_buffer.as_mut_ptr(), i) }; + } } - bool_buffer_builder.finish() + BooleanBuffer::new(mutable_buffer.into(), 0, array.len()) }; let bool_array = BooleanArray::new(bool_buffer, null_buffer); @@ -378,99 +382,99 @@ mod tests { negated, case_insensitive, typ, a_vec, b_lit, c_vec, case( false, false, DataType::Utf8, - Arc::new(StringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + Arc::new(StringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba", "abc", "bbb", "ABC", "ba", "cba"])), "^a", - Arc::new(BooleanArray::from(vec![true, false, false, false, false])), + Arc::new(BooleanArray::from(vec![true, false, false, false, false, true, false, false, false, false])), ), case( false, true, DataType::Utf8, - Arc::new(StringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + Arc::new(StringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba", "abc", "bbb", "ABC", "ba", "cba"])), "^a", - Arc::new(BooleanArray::from(vec![true, false, true, false, false])), + Arc::new(BooleanArray::from(vec![true, false, true, false, false, true, false, true, false, false])), ), case( true, false, DataType::Utf8, - Arc::new(StringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + Arc::new(StringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba", "abc", "bbb", "ABC", "ba", "cba"])), "^a", - Arc::new(BooleanArray::from(vec![false, true, true, true, true])), + Arc::new(BooleanArray::from(vec![false, true, true, true, true, false, true, true, true, true])), ), case( true, true, DataType::Utf8, - Arc::new(StringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + Arc::new(StringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba", "abc", "bbb", "ABC", "ba", "cba"])), "^a", - Arc::new(BooleanArray::from(vec![false, true, false, true, true])), + Arc::new(BooleanArray::from(vec![false, true, false, true, true, false, true, false, true, true])), ), case( true, true, DataType::Utf8, - 
Arc::new(StringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + Arc::new(StringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba", "abc", "bbb", "ABC", "ba", "cba"])), ScalarValue::Utf8(None), - Arc::new(BooleanArray::from(vec![None, None, None, None, None])), + Arc::new(BooleanArray::from(vec![None, None, None, None, None, None, None, None, None, None])), ), case( false, false, DataType::LargeUtf8, - Arc::new(LargeStringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + Arc::new(LargeStringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba", "abc", "bbb", "ABC", "ba", "cba"])), ScalarValue::LargeUtf8(Some("^a".to_string())), - Arc::new(BooleanArray::from(vec![true, false, false, false, false])), + Arc::new(BooleanArray::from(vec![true, false, false, false, false, true, false, false, false, false])), ), case( false, true, DataType::LargeUtf8, - Arc::new(LargeStringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + Arc::new(LargeStringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba", "abc", "bbb", "ABC", "ba", "cba"])), ScalarValue::LargeUtf8(Some("^a".to_string())), - Arc::new(BooleanArray::from(vec![true, false, true, false, false])), + Arc::new(BooleanArray::from(vec![true, false, true, false, false, true, false, true, false, false])), ), case( true, false, DataType::LargeUtf8, - Arc::new(LargeStringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + Arc::new(LargeStringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba", "abc", "bbb", "ABC", "ba", "cba"])), ScalarValue::LargeUtf8(Some("^a".to_string())), - Arc::new(BooleanArray::from(vec![false, true, true, true, true])), + Arc::new(BooleanArray::from(vec![false, true, true, true, true, false, true, true, true, true])), ), case( true, true, DataType::LargeUtf8, - Arc::new(LargeStringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + Arc::new(LargeStringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba", "abc", "bbb", "ABC", "ba", "cba"])), ScalarValue::LargeUtf8(Some("^a".to_string())), - 
Arc::new(BooleanArray::from(vec![false, true, false, true, true])), + Arc::new(BooleanArray::from(vec![false, true, false, true, true, false, true, false, true, true])), ), case( true, true, DataType::LargeUtf8, - Arc::new(LargeStringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + Arc::new(LargeStringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba", "abc", "bbb", "ABC", "ba", "cba"])), ScalarValue::LargeUtf8(None), - Arc::new(BooleanArray::from(vec![None, None, None, None, None])), + Arc::new(BooleanArray::from(vec![None, None, None, None, None, None, None, None, None, None])), ), case( false, false, DataType::Utf8View, - Arc::new(StringViewArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + Arc::new(StringViewArray::from(vec!["abc", "bbb", "ABC", "ba", "cba", "abc", "bbb", "ABC", "ba", "cba"])), ScalarValue::Utf8View(Some("^a".to_string())), - Arc::new(BooleanArray::from(vec![true, false, false, false, false])), + Arc::new(BooleanArray::from(vec![true, false, false, false, false, true, false, false, false, false])), ), case( false, true, DataType::Utf8View, - Arc::new(StringViewArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + Arc::new(StringViewArray::from(vec!["abc", "bbb", "ABC", "ba", "cba", "abc", "bbb", "ABC", "ba", "cba"])), ScalarValue::Utf8View(Some("^a".to_string())), - Arc::new(BooleanArray::from(vec![true, false, true, false, false])), + Arc::new(BooleanArray::from(vec![true, false, true, false, false, true, false, true, false, false])), ), case( true, false, DataType::Utf8View, - Arc::new(StringViewArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + Arc::new(StringViewArray::from(vec!["abc", "bbb", "ABC", "ba", "cba", "abc", "bbb", "ABC", "ba", "cba"])), ScalarValue::Utf8View(Some("^a".to_string())), - Arc::new(BooleanArray::from(vec![false, true, true, true, true])), + Arc::new(BooleanArray::from(vec![false, true, true, true, true, false, true, true, true, true])), ), case( true, true, DataType::Utf8View, - 
Arc::new(StringViewArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + Arc::new(StringViewArray::from(vec!["abc", "bbb", "ABC", "ba", "cba", "abc", "bbb", "ABC", "ba", "cba"])), ScalarValue::Utf8View(Some("^a".to_string())), - Arc::new(BooleanArray::from(vec![false, true, false, true, true])), + Arc::new(BooleanArray::from(vec![false, true, false, true, true, false, true, false, true, true])), ), case( true, true, DataType::Utf8View, - Arc::new(StringViewArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + Arc::new(StringViewArray::from(vec!["abc", "bbb", "ABC", "ba", "cba", "abc", "bbb", "ABC", "ba", "cba"])), ScalarValue::Utf8View(None), - Arc::new(BooleanArray::from(vec![None, None, None, None, None])), + Arc::new(BooleanArray::from(vec![None, None, None, None, None, None, None, None, None, None])), ), case( true, true, DataType::Null, - Arc::new(NullArray::new(5)), + Arc::new(NullArray::new(10)), ScalarValue::Utf8View(Some("^a".to_string())), - Arc::new(BooleanArray::from(vec![None, None, None, None, None])), + Arc::new(BooleanArray::from(vec![None, None, None, None, None, None, None, None, None, None])), ), )] fn test_scalar_regex_match_array( From c53ed7728a188ad42e1878896b74a19e86dcef8e Mon Sep 17 00:00:00 2001 From: zhuliquan Date: Mon, 16 Dec 2024 00:59:19 +0800 Subject: [PATCH 35/35] fix: take fmt suggestion --- datafusion/physical-expr/benches/scalar_regex_match.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/datafusion/physical-expr/benches/scalar_regex_match.rs b/datafusion/physical-expr/benches/scalar_regex_match.rs index 99dbe100507d..0526d48daf46 100644 --- a/datafusion/physical-expr/benches/scalar_regex_match.rs +++ b/datafusion/physical-expr/benches/scalar_regex_match.rs @@ -133,9 +133,7 @@ fn regex_match_benchmark(c: &mut Criterion) { ); let mut group = c.benchmark_group(group_name.as_str()); - group - .sample_size(50) - .measurement_time(Duration::new(30,0)); + 
group.sample_size(50).measurement_time(Duration::new(30, 0)); // binary expr match benchmarks group.bench_function("binary_expr_match", |b| {