diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 1f45bc3e83e1..57c8bc9681cc 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -361,14 +361,14 @@ pub fn to_substrait_rel( }))), })) } - LogicalPlan::Sort(sort) => { - let input = to_substrait_rel(sort.input.as_ref(), state, extensions)?; - let sort_fields = sort - .expr + LogicalPlan::Sort(datafusion::logical_expr::Sort { expr, input, fetch }) => { + let sort_fields = expr .iter() - .map(|e| substrait_sort_field(state, e, sort.input.schema(), extensions)) + .map(|e| substrait_sort_field(state, e, input.schema(), extensions)) .collect::>>()?; + let input = to_substrait_rel(input.as_ref(), state, extensions)?; + let sort_rel = Box::new(Rel { rel_type: Some(RelType::Sort(Box::new(SortRel { common: None, @@ -378,23 +378,16 @@ pub fn to_substrait_rel( }))), }); - match sort.fetch { - Some(_) => { - let empty_schema = Arc::new(DFSchema::empty()); - let count_mode = sort - .fetch - .map(|amount| { - to_substrait_rex( - state, - &Expr::Literal(ScalarValue::Int64(Some(amount as i64))), - &empty_schema, - 0, - extensions, - ) - }) - .transpose()? - .map(Box::new) - .map(fetch_rel::CountMode::CountExpr); + match fetch { + Some(amount) => { + let count_mode = + Some(fetch_rel::CountMode::CountExpr(Box::new(Expression { + rex_type: Some(RexType::Literal(Literal { + nullable: false, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + literal_type: Some(LiteralType::I64(*amount as i64)), + })), + }))); Ok(Box::new(Rel { rel_type: Some(RelType::Fetch(Box::new(FetchRel { common: None,