Skip to content

Commit

Permalink
Preserve constant values across union operations (#13805)
Browse files Browse the repository at this point in the history
* Add value tracking to ConstExpr for improved union optimization

* Update PartialEq impl

* Minor change

* Add docstring for ConstExpr value

* Improve constant propagation across union partitions

* Add assertion for across_partitions

* fix fmt

* Update properties.rs

* Remove redundant constant removal loop

* Remove unnecessary mut

* Set across_partitions=true when both sides are constant

* Extract and use constant values in filter expressions

* Add initial SLT for constant value tracking across UNION ALL

* Assign values to ConstExpr where possible

* Revert "Set across_partitions=true when both sides are constant"

This reverts commit 3051cd4.

* Temporarily take value from literal

* Lint fixes

* Cargo fmt

* Add get_expr_constant_value

* Make `with_value()` accept optional value

* Add todo

* Move test to union.slt

* Fix changed slt after merge

* Simplify constexpr

* Update properties.rs

---------

Co-authored-by: berkaysynnada <berkay.sahin@synnada.ai>
  • Loading branch information
gokselk and berkaysynnada authored Dec 25, 2024
1 parent 482b489 commit b9cef8c
Show file tree
Hide file tree
Showing 8 changed files with 303 additions and 90 deletions.
64 changes: 51 additions & 13 deletions datafusion/physical-expr/src/equivalence/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use std::fmt::Display;
use std::sync::Arc;

use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::JoinType;
use datafusion_common::{JoinType, ScalarValue};
use datafusion_physical_expr_common::physical_expr::format_physical_expr_list;

use indexmap::{IndexMap, IndexSet};
Expand Down Expand Up @@ -55,13 +55,45 @@ use indexmap::{IndexMap, IndexSet};
/// // create a constant expression from a physical expression
/// let const_expr = ConstExpr::from(col);
/// ```
// TODO: Consider refactoring the `across_partitions` and `value` fields into an enum:
//
// ```
// enum PartitionValues {
// Uniform(Option<ScalarValue>), // Same value across all partitions
// Heterogeneous(Vec<Option<ScalarValue>>) // Different values per partition
// }
// ```
//
// This would provide more flexible representation of partition values.
// Note: This is a breaking change for the equivalence API and should be
// addressed in a separate issue/PR.
#[derive(Debug, Clone)]
pub struct ConstExpr {
/// The expression that is known to be constant (e.g. a `Column`)
expr: Arc<dyn PhysicalExpr>,
/// Does the constant have the same value across all partitions? See
/// struct docs for more details
across_partitions: bool,
across_partitions: AcrossPartitions,
}

#[derive(PartialEq, Clone, Debug)]
/// Represents whether a constant expression's value is uniform or varies across partitions.
///
/// The `AcrossPartitions` enum is used to describe the nature of a constant expression
/// in a physical execution plan:
///
/// - `Heterogeneous`: The constant expression may have different values for different partitions.
/// - `Uniform(Option<ScalarValue>)`: The constant expression has the same value across all partitions,
/// or is `None` if the value is not specified.
pub enum AcrossPartitions {
Heterogeneous,
Uniform(Option<ScalarValue>),
}

impl Default for AcrossPartitions {
fn default() -> Self {
Self::Heterogeneous
}
}

impl PartialEq for ConstExpr {
Expand All @@ -79,23 +111,23 @@ impl ConstExpr {
Self {
expr,
// By default, assume constant expressions are not same across partitions.
across_partitions: false,
across_partitions: Default::default(),
}
}

/// Set the `across_partitions` flag
///
/// See struct docs for more details
pub fn with_across_partitions(mut self, across_partitions: bool) -> Self {
pub fn with_across_partitions(mut self, across_partitions: AcrossPartitions) -> Self {
self.across_partitions = across_partitions;
self
}

/// Is the expression the same across all partitions?
///
/// See struct docs for more details
pub fn across_partitions(&self) -> bool {
self.across_partitions
pub fn across_partitions(&self) -> AcrossPartitions {
self.across_partitions.clone()
}

pub fn expr(&self) -> &Arc<dyn PhysicalExpr> {
Expand All @@ -113,7 +145,7 @@ impl ConstExpr {
let maybe_expr = f(&self.expr);
maybe_expr.map(|expr| Self {
expr,
across_partitions: self.across_partitions,
across_partitions: self.across_partitions.clone(),
})
}

Expand Down Expand Up @@ -143,14 +175,20 @@ impl ConstExpr {
}
}

/// Display implementation for `ConstExpr`
///
/// Example `c` or `c(across_partitions)`
impl Display for ConstExpr {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.expr)?;
if self.across_partitions {
write!(f, "(across_partitions)")?;
match &self.across_partitions {
AcrossPartitions::Heterogeneous => {
write!(f, "(heterogeneous)")?;
}
AcrossPartitions::Uniform(value) => {
if let Some(val) = value {
write!(f, "(uniform: {})", val)?;
} else {
write!(f, "(uniform: unknown)")?;
}
}
}
Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/physical-expr/src/equivalence/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ mod ordering;
mod projection;
mod properties;

pub use class::{ConstExpr, EquivalenceClass, EquivalenceGroup};
pub use class::{AcrossPartitions, ConstExpr, EquivalenceClass, EquivalenceGroup};
pub use ordering::OrderingEquivalenceClass;
pub use projection::ProjectionMapping;
pub use properties::{
Expand Down
9 changes: 5 additions & 4 deletions datafusion/physical-expr/src/equivalence/ordering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ mod tests {
};
use crate::expressions::{col, BinaryExpr, Column};
use crate::utils::tests::TestScalarUDF;
use crate::{ConstExpr, PhysicalExpr, PhysicalSortExpr};
use crate::{AcrossPartitions, ConstExpr, PhysicalExpr, PhysicalSortExpr};

use arrow::datatypes::{DataType, Field, Schema};
use arrow_schema::SortOptions;
Expand Down Expand Up @@ -583,9 +583,10 @@ mod tests {
let eq_group = EquivalenceGroup::new(eq_group);
eq_properties.add_equivalence_group(eq_group);

let constants = constants
.into_iter()
.map(|expr| ConstExpr::from(expr).with_across_partitions(true));
let constants = constants.into_iter().map(|expr| {
ConstExpr::from(expr)
.with_across_partitions(AcrossPartitions::Uniform(None))
});
eq_properties = eq_properties.with_constants(constants);

let reqs = convert_to_sort_exprs(&reqs);
Expand Down
Loading

0 comments on commit b9cef8c

Please sign in to comment.