Skip to content

Commit

Permalink
refactor: DfUdfAdapter to bridge ScalarUdf
Browse files Browse the repository at this point in the history
Signed-off-by: tison <wander4096@gmail.com>
  • Loading branch information
tisonkun committed Apr 26, 2024
1 parent e410192 commit 1010e6f
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 71 deletions.
117 changes: 57 additions & 60 deletions src/common/query/src/logical_plan/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,74 +91,71 @@ impl AggregateFunction {
}
}

impl From<AggregateFunction> for DfAggregateUdf {
fn from(udaf: AggregateFunction) -> Self {
/// Adapter holding the parts of an `AggregateFunction` so that they can be
/// served through the `AggregateUDFImpl` trait.
struct DfUdafAdapter {
/// Function name, handed out by `name()`.
name: String,
/// Function signature, handed out by `signature()`.
signature: datafusion_expr::Signature,
/// Callback invoked with the argument types to resolve the return type.
return_type_func: datafusion_expr::ReturnTypeFunction,
/// Factory invoked with the accumulator arguments to build an accumulator.
accumulator: AccumulatorFactoryFunction,
/// Creator whose `state_types()` supplies the types of the state fields.
creator: AggregateFunctionCreatorRef,
}
/// Adapter holding the parts of an `AggregateFunction` so that they can be
/// served through the `AggregateUDFImpl` trait.
struct DfUdafAdapter {
/// Function name, handed out by `name()`.
name: String,
/// Function signature, handed out by `signature()`.
signature: datafusion_expr::Signature,
/// Callback invoked with the argument types to resolve the return type.
return_type_func: datafusion_expr::ReturnTypeFunction,
/// Factory invoked with the accumulator arguments to build an accumulator.
accumulator: AccumulatorFactoryFunction,
/// Creator whose `state_types()` supplies the types of the state fields.
creator: AggregateFunctionCreatorRef,
}

/// Debug representation limited to the adapter's `name` and `signature`;
/// the remaining fields are not included in the output.
impl Debug for DfUdafAdapter {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let mut repr = f.debug_struct("DfUdafAdapter");
        repr.field("name", &self.name);
        repr.field("signature", &self.signature);
        repr.finish()
    }
}
/// Debug representation limited to the adapter's `name` and `signature`;
/// the remaining fields are not included in the output.
impl Debug for DfUdafAdapter {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let mut repr = f.debug_struct("DfUdafAdapter");
        repr.field("name", &self.name);
        repr.field("signature", &self.signature);
        repr.finish()
    }
}

/// Serves the adapter's stored name, signature and callbacks through the
/// `AggregateUDFImpl` trait.
impl AggregateUDFImpl for DfUdafAdapter {
    /// Returns the adapter itself as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns the stored function name.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns a reference to the stored signature.
    fn signature(&self) -> &datafusion_expr::Signature {
        &self.signature
    }

    /// Calls the stored return-type callback with `arg_types` and returns an
    /// owned copy of the type it yields; a callback error is passed through.
    fn return_type(&self, arg_types: &[ArrowDataType]) -> Result<ArrowDataType> {
        let resolved = (self.return_type_func)(arg_types)?;
        Ok(resolved.as_ref().clone())
    }

    /// Delegates to the stored accumulator factory.
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        (self.accumulator)(acc_args)
    }

    /// Builds one nullable field per state type reported by the creator.
    ///
    /// Fields are named `{name}_{index}`, with `index` counting from zero in
    /// the order the creator lists its state types. An error from
    /// `state_types()` is converted with `into()` and returned.
    fn state_fields(
        &self,
        name: &str,
        _value_type: ArrowDataType,
        _ordering_fields: Vec<Field>,
    ) -> Result<Vec<Field>> {
        match self.creator.state_types() {
            Ok(state_types) => {
                let fields = state_types
                    .into_iter()
                    .enumerate()
                    .map(|(idx, ty)| {
                        Field::new(format!("{}_{}", name, idx), ty.as_arrow_type(), true)
                    })
                    .collect::<Vec<_>>();
                Ok(fields)
            }
            Err(err) => Err(err.into()),
        }
    }
}
/// Serves the adapter's stored name, signature and callbacks through the
/// `AggregateUDFImpl` trait.
impl AggregateUDFImpl for DfUdafAdapter {
/// Returns the adapter itself as `&dyn Any`.
fn as_any(&self) -> &dyn Any {
self
}

DfUdafAdapter {
/// Returns the stored function name.
fn name(&self) -> &str {
&self.name
}

/// Returns a reference to the stored signature.
fn signature(&self) -> &datafusion_expr::Signature {
&self.signature
}

/// Calls the stored return-type callback with `arg_types` and returns an
/// owned copy of the type it yields.
fn return_type(&self, arg_types: &[ArrowDataType]) -> Result<ArrowDataType> {
(self.return_type_func)(arg_types).map(|x| x.as_ref().clone())
}

/// Delegates to the stored accumulator factory.
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
(self.accumulator)(acc_args)
}

/// Builds one nullable field per state type reported by the creator.
///
/// Fields are named `{name}_{index}`, with `index` counting from zero in the
/// order the creator lists its state types. `_value_type` and
/// `_ordering_fields` are not used.
fn state_fields(
&self,
name: &str,
_value_type: ArrowDataType,
_ordering_fields: Vec<Field>,
) -> Result<Vec<Field>> {
self.creator
.state_types()
.map(|x| {
// Pair each state type with its position to derive the field name.
(0..x.len())
.zip(x)
.map(|(i, t)| Field::new(format!("{}_{}", name, i), t.as_arrow_type(), true))
.collect::<Vec<_>>()
})
// Convert the creator's error into this method's error type.
.map_err(|e| e.into())
}
}

/// Converts an `AggregateFunction` into a `DfAggregateUdf`.
impl From<AggregateFunction> for DfAggregateUdf {
/// Moves the function's parts into a `DfUdafAdapter` and wraps the adapter
/// with `DfAggregateUdf::new_from_impl`.
fn from(udaf: AggregateFunction) -> Self {
DfAggregateUdf::new_from_impl(DfUdafAdapter {
name: udaf.name,
signature: udaf.signature.into(),
return_type_func: to_df_return_type(udaf.return_type),
// The creator is cloned here because it is also moved into the `creator` field below.
accumulator: to_df_accumulator_func(udaf.accumulator, udaf.creator.clone()),
creator: udaf.creator,
}
.into()
})
}
}

Expand Down
61 changes: 50 additions & 11 deletions src/common/query/src/logical_plan/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@

//! Udf module contains foundational types that are used to represent UDFs.
//! It's modified from datafusion.
use std::any::Any;
use std::fmt;
use std::fmt::{Debug, Formatter};
use std::sync::Arc;

use datafusion_expr::{
ColumnarValue as DfColumnarValue,
ScalarFunctionImplementation as DfScalarFunctionImplementation, ScalarUDF as DfScalarUDF,
ScalarUDFImpl,
};
use datatypes::arrow::datatypes::DataType;

use crate::error::Result;
use crate::function::{ReturnTypeFunction, ScalarFunctionImplementation};
Expand Down Expand Up @@ -68,25 +71,61 @@ impl ScalarUdf {
}
}

/// Adapter holding the parts of a `ScalarUdf` so that they can be served
/// through the `ScalarUDFImpl` trait.
#[derive(Clone)]
struct DfUdfAdapter {
/// Function name, handed out by `name()`.
name: String,
/// Function signature, handed out by `signature()`.
signature: datafusion_expr::Signature,
/// Callback invoked with the argument types to resolve the return type.
return_type: datafusion_expr::ReturnTypeFunction,
/// Function implementation called by `invoke()`.
fun: DfScalarFunctionImplementation,
}

/// Debug representation showing `name` and `signature`; the function
/// implementation is rendered as the placeholder `"<FUNC>"` and the
/// return-type callback is left out.
impl Debug for DfUdfAdapter {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let mut repr = f.debug_struct("DfUdfAdapter");
        repr.field("name", &self.name);
        repr.field("signature", &self.signature);
        repr.field("fun", &"<FUNC>");
        repr.finish()
    }
}

/// Serves the adapter's stored name, signature and callbacks through the
/// `ScalarUDFImpl` trait.
impl ScalarUDFImpl for DfUdfAdapter {
    /// Returns the adapter itself as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns the stored function name.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns a reference to the stored signature.
    fn signature(&self) -> &datafusion_expr::Signature {
        &self.signature
    }

    /// Calls the stored return-type callback with `arg_types` and returns an
    /// owned copy of the type it yields; a callback error is passed through.
    fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
        let resolved = (self.return_type)(arg_types)?;
        Ok(resolved.as_ref().clone())
    }

    /// Delegates to the stored function implementation.
    fn invoke(&self, args: &[DfColumnarValue]) -> datafusion_common::Result<DfColumnarValue> {
        (self.fun)(args)
    }
}

/// Converts a `ScalarUdf` into DataFusion's scalar UDF type (`DfScalarUDF`).
impl From<ScalarUdf> for DfScalarUDF {
/// Converts the signature with `into()` and adapts the return-type function
/// and the implementation with `to_df_return_type` and `to_df_scalar_func`.
fn from(udf: ScalarUdf) -> Self {
// TODO(LFC): remove deprecated
#[allow(deprecated)]
DfScalarUDF::new(
&udf.name,
&udf.signature.into(),
&to_df_return_type(udf.return_type),
&to_df_scalar_func(udf.fun),
)
DfScalarUDF::new_from_impl(DfUdfAdapter {
name: udf.name,
signature: udf.signature.into(),
return_type: to_df_return_type(udf.return_type),
fun: to_df_scalar_func(udf.fun),
})
}
}

/// Wraps a `ScalarFunctionImplementation` in an `Arc`-ed closure that takes
/// and returns `DfColumnarValue`s.
///
/// The closure converts every argument with `TryFrom`, calls `fun` on the
/// converted arguments, then converts the success value with `From` and the
/// error with `into()`.
fn to_df_scalar_func(fun: ScalarFunctionImplementation) -> DfScalarFunctionImplementation {
Arc::new(move |args: &[DfColumnarValue]| {
// Collecting into `Result<Vec<_>>` stops at the first failed conversion.
let args: Result<Vec<_>> = args.iter().map(TryFrom::try_from).collect();

let result = (fun)(&args?);

let result = fun(&args?);
result.map(From::from).map_err(|e| e.into())
})
}

0 comments on commit 1010e6f

Please sign in to comment.