From 7b6c37fdd934dec301a4036a72cd5c206220a9bc Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Tue, 24 Dec 2024 19:18:34 +0800 Subject: [PATCH] seperate the enum for the unparsing result --- datafusion-examples/examples/plan_to_sql.rs | 16 +++++++----- .../sql/src/unparser/extension_unparser.rs | 26 ++++++++++++------- datafusion/sql/src/unparser/plan.rs | 22 +++++----------- datafusion/sql/tests/cases/plan_to_sql.rs | 19 +++++++------- 4 files changed, 42 insertions(+), 41 deletions(-) diff --git a/datafusion-examples/examples/plan_to_sql.rs b/datafusion-examples/examples/plan_to_sql.rs index 2e176199d74c..43a7f19dc6c9 100644 --- a/datafusion-examples/examples/plan_to_sql.rs +++ b/datafusion-examples/examples/plan_to_sql.rs @@ -28,8 +28,10 @@ use datafusion_sql::unparser::ast::{ DerivedRelationBuilder, QueryBuilder, RelationBuilder, SelectBuilder, }; use datafusion_sql::unparser::dialect::CustomDialectBuilder; -use datafusion_sql::unparser::extension_unparser::UnparseResult; use datafusion_sql::unparser::extension_unparser::UserDefinedLogicalNodeUnparser; +use datafusion_sql::unparser::extension_unparser::{ + UnparseToStatementResult, UnparseWithinStatementResult, +}; use datafusion_sql::unparser::{plan_to_sql, Unparser}; use std::fmt; use std::sync::Arc; @@ -214,12 +216,12 @@ impl UserDefinedLogicalNodeUnparser for PlanToStatement { &self, node: &dyn UserDefinedLogicalNode, unparser: &Unparser, - ) -> Result { + ) -> Result { if let Some(plan) = node.as_any().downcast_ref::() { let input = unparser.plan_to_sql(&plan.input)?; - Ok(UnparseResult::Statement(input)) + Ok(UnparseToStatementResult::Modified(input)) } else { - Ok(UnparseResult::Original) + Ok(UnparseToStatementResult::Unmodified) } } } @@ -261,10 +263,10 @@ impl UserDefinedLogicalNodeUnparser for PlanToSubquery { _query: &mut Option<&mut QueryBuilder>, _select: &mut Option<&mut SelectBuilder>, relation: &mut Option<&mut RelationBuilder>, - ) -> Result { + ) -> Result { if let Some(plan) = node.as_any().downcast_ref::() { let Statement::Query(input) = unparser.plan_to_sql(&plan.input)? else { - return Ok(UnparseResult::Original); + return Ok(UnparseWithinStatementResult::Unmodified); }; let mut derived_builder = DerivedRelationBuilder::default(); derived_builder.subquery(input); @@ -273,7 +275,7 @@ impl UserDefinedLogicalNodeUnparser for PlanToSubquery { rel.derived(derived_builder); } } - Ok(UnparseResult::WithinStatement) + Ok(UnparseWithinStatementResult::Modified) } } diff --git a/datafusion/sql/src/unparser/extension_unparser.rs b/datafusion/sql/src/unparser/extension_unparser.rs index 3ad224422ae6..d3161ced7b4c 100644 --- a/datafusion/sql/src/unparser/extension_unparser.rs +++ b/datafusion/sql/src/unparser/extension_unparser.rs @@ -36,8 +36,8 @@ pub trait UserDefinedLogicalNodeUnparser { _query: &mut Option<&mut QueryBuilder>, _select: &mut Option<&mut SelectBuilder>, _relation: &mut Option<&mut RelationBuilder>, - ) -> datafusion_common::Result { - Ok(UnparseResult::Original) + ) -> datafusion_common::Result { + Ok(UnparseWithinStatementResult::Unmodified) } /// Unparse the custom logical node to a statement. @@ -50,17 +50,23 @@ pub trait UserDefinedLogicalNodeUnparser { &self, _node: &dyn UserDefinedLogicalNode, _unparser: &Unparser, - ) -> datafusion_common::Result { - Ok(UnparseResult::Original) + ) -> datafusion_common::Result { + Ok(UnparseToStatementResult::Unmodified) } } -/// The result of unparsing a custom logical node. -pub enum UnparseResult { - /// If the custom logical node was successfully unparsed and return a statement. - Statement(Statement), +/// The result of unparsing a custom logical node within a statement. +pub enum UnparseWithinStatementResult { /// If the custom logical node was successfully unparsed within a statement. - WithinStatement, + Modified, /// If the custom logical node wasn't unparsed. - Original, + Unmodified, +} + +/// The result of unparsing a custom logical node to a statement. +pub enum UnparseToStatementResult { + /// If the custom logical node was successfully unparsed to a statement. + Modified(Statement), + /// If the custom logical node wasn't unparsed. + Unmodified, } diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index d282100a3a0f..3dcf0f66747c 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -33,7 +33,9 @@ use super::{ Unparser, }; use crate::unparser::ast::UnnestRelationBuilder; -use crate::unparser::extension_unparser::UnparseResult; +use crate::unparser::extension_unparser::{ + UnparseToStatementResult, UnparseWithinStatementResult, +}; use crate::unparser::utils::unproject_agg_exprs; use crate::utils::UNNEST_PLACEHOLDER; use datafusion_common::{ @@ -135,16 +137,11 @@ impl Unparser<'_> { let mut statement = None; for unparser in &self.extension_unparsers { match unparser.unparse_to_statement(node, self)? { - UnparseResult::Statement(stmt) => { + UnparseToStatementResult::Modified(stmt) => { statement = Some(stmt); break; } - UnparseResult::WithinStatement => { - return not_impl_err!( - "UnparseResult::WithinStatement is not supported for `extension_to_statement`" - ); - } - UnparseResult::Original => {} + UnparseToStatementResult::Unmodified => {} } } if let Some(statement) = statement { @@ -166,13 +163,8 @@ impl Unparser<'_> { ) -> Result<()> { for unparser in &self.extension_unparsers { match unparser.unparse(node, self, query, select, relation)? { - UnparseResult::WithinStatement => return Ok(()), - UnparseResult::Original => {} - UnparseResult::Statement(_) => { - return not_impl_err!( - "UnparseResult::Statement is not supported for `extension_to_sql`" - ); - } + UnparseWithinStatementResult::Modified => return Ok(()), + UnparseWithinStatementResult::Unmodified => {} } } not_impl_err!("Unsupported extension node: {node:?}") diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index d3f6397fef51..3fdd4f74a0c2 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -51,7 +51,8 @@ use datafusion_sql::unparser::ast::{ DerivedRelationBuilder, QueryBuilder, RelationBuilder, SelectBuilder, }; use datafusion_sql::unparser::extension_unparser::{ - UnparseResult, UserDefinedLogicalNodeUnparser, + UnparseToStatementResult, UnparseWithinStatementResult, + UserDefinedLogicalNodeUnparser, }; use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect}; use sqlparser::parser::Parser; @@ -1461,12 +1462,12 @@ impl UserDefinedLogicalNodeUnparser for MockStatementUnparser { &self, node: &dyn UserDefinedLogicalNode, unparser: &Unparser, - ) -> Result { + ) -> Result { if let Some(plan) = node.as_any().downcast_ref::() { let input = unparser.plan_to_sql(&plan.input)?; - Ok(UnparseResult::Statement(input)) + Ok(UnparseToStatementResult::Modified(input)) } else { - Ok(UnparseResult::Original) + Ok(UnparseToStatementResult::Unmodified) } } } @@ -1481,7 +1482,7 @@ impl UserDefinedLogicalNodeUnparser for UnusedUnparser { _query: &mut Option<&mut QueryBuilder>, _select: &mut Option<&mut SelectBuilder>, _relation: &mut Option<&mut RelationBuilder>, - ) -> Result { + ) -> Result { panic!("This should not be called"); } @@ -1489,7 +1490,7 @@ impl UserDefinedLogicalNodeUnparser for UnusedUnparser { &self, _node: &dyn UserDefinedLogicalNode, _unparser: &Unparser, - ) -> Result { + ) -> Result { panic!("This should not be called"); } } @@ -1537,10 +1538,10 @@ impl UserDefinedLogicalNodeUnparser for MockSqlUnparser { _query: &mut Option<&mut QueryBuilder>, _select: &mut Option<&mut SelectBuilder>, relation: &mut Option<&mut RelationBuilder>, - ) -> Result { + ) -> Result { if let Some(plan) = node.as_any().downcast_ref::() { let Statement::Query(input) = unparser.plan_to_sql(&plan.input)? else { - return Ok(UnparseResult::Original); + return Ok(UnparseWithinStatementResult::Unmodified); }; let mut derived_builder = DerivedRelationBuilder::default(); derived_builder.subquery(input); @@ -1549,7 +1550,7 @@ impl UserDefinedLogicalNodeUnparser for MockSqlUnparser { rel.derived(derived_builder); } } - Ok(UnparseResult::WithinStatement) + Ok(UnparseWithinStatementResult::Modified) } }