Skip to content

Commit

Permalink
use MedianGroupsAccumulator.
Browse files Browse the repository at this point in the history
  • Loading branch information
Rachelint committed Dec 8, 2024
1 parent df685a5 commit ded4e5f
Showing 1 changed file with 40 additions and 1 deletion.
41 changes: 40 additions & 1 deletion datafusion/functions-aggregate/src/median.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ use arrow::array::Array;
use arrow::array::ArrowNativeTypeOp;
use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType};

use datafusion_common::{DataFusionError, HashSet, Result, ScalarValue};
use datafusion_common::{internal_err, DataFusionError, HashSet, Result, ScalarValue};
use datafusion_doc::DocSection;
use datafusion_expr::function::StateFieldsArgs;
use datafusion_expr::{
Expand Down Expand Up @@ -173,6 +173,45 @@ impl AggregateUDFImpl for Median {
}
}

fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
!args.is_distinct
}

fn create_groups_accumulator(
&self,
args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
let num_args = args.exprs.len();
if num_args != 1 {
return internal_err!(
"median should only have 1 arg, but found num args:{}",
args.exprs.len()
);
}

let dt = args.exprs[0].data_type(args.schema)?;

macro_rules! helper {
($t:ty, $dt:expr) => {
Ok(Box::new(MedianGroupsAccumulator::<$t>::new($dt)))
};
}

downcast_integer! {
dt => (helper, dt),
DataType::Float16 => helper!(Float16Type, dt),
DataType::Float32 => helper!(Float32Type, dt),
DataType::Float64 => helper!(Float64Type, dt),
DataType::Decimal128(_, _) => helper!(Decimal128Type, dt),
DataType::Decimal256(_, _) => helper!(Decimal256Type, dt),
_ => Err(DataFusionError::NotImplemented(format!(
"MedianGroupsAccumulator not supported for {} with {}",
args.name,
dt,
))),
}
}

fn aliases(&self) -> &[String] {
&[]
}
Expand Down

0 comments on commit ded4e5f

Please sign in to comment.