Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consolidate example: simplify_udaf_expression.rs into advanced_udaf.rs #13905

Merged
merged 1 commit into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 132 additions & 53 deletions datafusion-examples/examples/advanced_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_common::{cast::as_float64_array, ScalarValue};
use datafusion_expr::{
function::{AccumulatorArgs, StateFieldsArgs},
expr::AggregateFunction,
function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs},
simplify::SimplifyInfo,
Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
};

Expand Down Expand Up @@ -197,40 +199,6 @@ impl Accumulator for GeometricMean {
}
}

// create local session context with an in-memory table
fn create_context() -> Result<SessionContext> {
use datafusion::datasource::MemTable;
// define a schema.
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Float32, false),
Field::new("b", DataType::Float32, false),
]));

// define data in two partitions
let batch1 = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0])),
Arc::new(Float32Array::from(vec![2.0, 2.0, 2.0])),
],
)?;
let batch2 = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Float32Array::from(vec![64.0])),
Arc::new(Float32Array::from(vec![2.0])),
],
)?;

// declare a new context. In spark API, this corresponds to a new spark SQLsession
let ctx = SessionContext::new();

// declare a table in memory. In spark API, this corresponds to createDataFrame(...).
let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?;
ctx.register_table("t", Arc::new(provider))?;
Ok(ctx)
}

// Define a `GroupsAccumulator` for GeometricMean
/// which handles accumulator state for multiple groups at once.
/// This API is significantly more complicated than `Accumulator`, which manages
Expand Down Expand Up @@ -399,35 +367,146 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator {
}
}

/// This example shows how to use the AggregateUDFImpl::simplify API to simplify/replace user
/// defined aggregate function with a different expression which is defined in the `simplify` method.
#[derive(Debug, Clone)]
struct SimplifiedGeoMeanUdaf {
signature: Signature,
}

impl SimplifiedGeoMeanUdaf {
fn new() -> Self {
Self {
signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
}
}
}

impl AggregateUDFImpl for SimplifiedGeoMeanUdaf {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"simplified_geo_mean"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Float64)
}

fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
unimplemented!("should not be invoked")
}

fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
unimplemented!("should not be invoked")
}

fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
true
}

fn create_groups_accumulator(
&self,
_args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
unimplemented!("should not get here");
}

/// Optionally replaces a UDAF with another expression during query optimization.
fn simplify(&self) -> Option<AggregateFunctionSimplification> {
let simplify = |aggregate_function: AggregateFunction, _: &dyn SimplifyInfo| {
// Replaces the UDAF with `GeoMeanUdaf` as a placeholder example to demonstrate the `simplify` method.
// In real-world scenarios, you might create UDFs from built-in expressions.
Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
Arc::new(AggregateUDF::from(GeoMeanUdaf::new())),
aggregate_function.args,
aggregate_function.distinct,
aggregate_function.filter,
aggregate_function.order_by,
aggregate_function.null_treatment,
)))
};
Some(Box::new(simplify))
}
}

// create local session context with an in-memory table
fn create_context() -> Result<SessionContext> {
use datafusion::datasource::MemTable;
// define a schema.
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Float32, false),
Field::new("b", DataType::Float32, false),
]));

// define data in two partitions
let batch1 = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0])),
Arc::new(Float32Array::from(vec![2.0, 2.0, 2.0])),
],
)?;
let batch2 = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Float32Array::from(vec![64.0])),
Arc::new(Float32Array::from(vec![2.0])),
],
)?;

// declare a new context. In spark API, this corresponds to a new spark SQLsession
let ctx = SessionContext::new();

// declare a table in memory. In spark API, this corresponds to createDataFrame(...).
let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?;
ctx.register_table("t", Arc::new(provider))?;
Ok(ctx)
}

#[tokio::main]
async fn main() -> Result<()> {
let ctx = create_context()?;

// create the AggregateUDF
let geometric_mean = AggregateUDF::from(GeoMeanUdaf::new());
ctx.register_udaf(geometric_mean.clone());
let geo_mean_udf = AggregateUDF::from(GeoMeanUdaf::new());
let simplified_geo_mean_udf = AggregateUDF::from(SimplifiedGeoMeanUdaf::new());

for (udf, udf_name) in [
(geo_mean_udf, "geo_mean"),
(simplified_geo_mean_udf, "simplified_geo_mean"),
] {
ctx.register_udaf(udf.clone());

let sql_df = ctx.sql("SELECT geo_mean(a) FROM t group by b").await?;
sql_df.show().await?;
let sql_df = ctx
.sql(&format!("SELECT {}(a) FROM t GROUP BY b", udf_name))
.await?;
sql_df.show().await?;

// get a DataFrame from the context
// this table has 1 column `a` f32 with values {2,4,8,64}, whose geometric mean is 8.0.
let df = ctx.table("t").await?;
// get a DataFrame from the context
// this table has 1 column `a` f32 with values {2,4,8,64}, whose geometric mean is 8.0.
let df = ctx.table("t").await?;

// perform the aggregation
let df = df.aggregate(vec![], vec![geometric_mean.call(vec![col("a")])])?;
// perform the aggregation
let df = df.aggregate(vec![], vec![udf.call(vec![col("a")])])?;

// note that "a" is f32, not f64. DataFusion coerces it to match the UDAF's signature.
// note that "a" is f32, not f64. DataFusion coerces it to match the UDAF's signature.

// execute the query
let results = df.collect().await?;
// execute the query
let results = df.collect().await?;

// downcast the array to the expected type
let result = as_float64_array(results[0].column(0))?;
// downcast the array to the expected type
let result = as_float64_array(results[0].column(0))?;

// verify that the calculation is correct
assert!((result.value(0) - 8.0).abs() < f64::EPSILON);
println!("The geometric mean of [2,4,8,64] is {}", result.value(0));
// verify that the calculation is correct
assert!((result.value(0) - 8.0).abs() < f64::EPSILON);
println!("The geometric mean of [2,4,8,64] is {}", result.value(0));
}

Ok(())
}
176 changes: 0 additions & 176 deletions datafusion-examples/examples/simplify_udaf_expression.rs

This file was deleted.

Loading