Skip to content

Commit

Permalink
seperate the enum for the unparsing result
Browse files Browse the repository at this point in the history
  • Loading branch information
goldmedal committed Dec 24, 2024
1 parent 856a5f5 commit 7b6c37f
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 41 deletions.
16 changes: 9 additions & 7 deletions datafusion-examples/examples/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ use datafusion_sql::unparser::ast::{
DerivedRelationBuilder, QueryBuilder, RelationBuilder, SelectBuilder,
};
use datafusion_sql::unparser::dialect::CustomDialectBuilder;
use datafusion_sql::unparser::extension_unparser::UnparseResult;
use datafusion_sql::unparser::extension_unparser::UserDefinedLogicalNodeUnparser;
use datafusion_sql::unparser::extension_unparser::{
UnparseToStatementResult, UnparseWithinStatementResult,
};
use datafusion_sql::unparser::{plan_to_sql, Unparser};
use std::fmt;
use std::sync::Arc;
Expand Down Expand Up @@ -214,12 +216,12 @@ impl UserDefinedLogicalNodeUnparser for PlanToStatement {
&self,
node: &dyn UserDefinedLogicalNode,
unparser: &Unparser,
) -> Result<UnparseResult> {
) -> Result<UnparseToStatementResult> {
if let Some(plan) = node.as_any().downcast_ref::<MyLogicalPlan>() {
let input = unparser.plan_to_sql(&plan.input)?;
Ok(UnparseResult::Statement(input))
Ok(UnparseToStatementResult::Modified(input))
} else {
Ok(UnparseResult::Original)
Ok(UnparseToStatementResult::Unmodified)
}
}
}
Expand Down Expand Up @@ -261,10 +263,10 @@ impl UserDefinedLogicalNodeUnparser for PlanToSubquery {
_query: &mut Option<&mut QueryBuilder>,
_select: &mut Option<&mut SelectBuilder>,
relation: &mut Option<&mut RelationBuilder>,
) -> Result<UnparseResult> {
) -> Result<UnparseWithinStatementResult> {
if let Some(plan) = node.as_any().downcast_ref::<MyLogicalPlan>() {
let Statement::Query(input) = unparser.plan_to_sql(&plan.input)? else {
return Ok(UnparseResult::Original);
return Ok(UnparseWithinStatementResult::Unmodified);
};
let mut derived_builder = DerivedRelationBuilder::default();
derived_builder.subquery(input);
Expand All @@ -273,7 +275,7 @@ impl UserDefinedLogicalNodeUnparser for PlanToSubquery {
rel.derived(derived_builder);
}
}
Ok(UnparseResult::WithinStatement)
Ok(UnparseWithinStatementResult::Modified)
}
}

Expand Down
26 changes: 16 additions & 10 deletions datafusion/sql/src/unparser/extension_unparser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ pub trait UserDefinedLogicalNodeUnparser {
_query: &mut Option<&mut QueryBuilder>,
_select: &mut Option<&mut SelectBuilder>,
_relation: &mut Option<&mut RelationBuilder>,
) -> datafusion_common::Result<UnparseResult> {
Ok(UnparseResult::Original)
) -> datafusion_common::Result<UnparseWithinStatementResult> {
Ok(UnparseWithinStatementResult::Unmodified)
}

/// Unparse the custom logical node to a statement.
Expand All @@ -50,17 +50,23 @@ pub trait UserDefinedLogicalNodeUnparser {
&self,
_node: &dyn UserDefinedLogicalNode,
_unparser: &Unparser,
) -> datafusion_common::Result<UnparseResult> {
Ok(UnparseResult::Original)
) -> datafusion_common::Result<UnparseToStatementResult> {
Ok(UnparseToStatementResult::Unmodified)
}
}

/// The result of unparsing a custom logical node.
pub enum UnparseResult {
/// If the custom logical node was successfully unparsed and return a statement.
Statement(Statement),
/// The result of unparsing a custom logical node within a statement.
pub enum UnparseWithinStatementResult {
/// If the custom logical node was successfully unparsed within a statement.
WithinStatement,
Modified,
/// If the custom logical node wasn't unparsed.
Original,
Unmodified,
}

/// The result of unparsing a custom logical node to a statement.
pub enum UnparseToStatementResult {
/// If the custom logical node was successfully unparsed to a statement.
Modified(Statement),
/// If the custom logical node wasn't unparsed.
Unmodified,
}
22 changes: 7 additions & 15 deletions datafusion/sql/src/unparser/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ use super::{
Unparser,
};
use crate::unparser::ast::UnnestRelationBuilder;
use crate::unparser::extension_unparser::UnparseResult;
use crate::unparser::extension_unparser::{
UnparseToStatementResult, UnparseWithinStatementResult,
};
use crate::unparser::utils::unproject_agg_exprs;
use crate::utils::UNNEST_PLACEHOLDER;
use datafusion_common::{
Expand Down Expand Up @@ -135,16 +137,11 @@ impl Unparser<'_> {
let mut statement = None;
for unparser in &self.extension_unparsers {
match unparser.unparse_to_statement(node, self)? {
UnparseResult::Statement(stmt) => {
UnparseToStatementResult::Modified(stmt) => {
statement = Some(stmt);
break;
}
UnparseResult::WithinStatement => {
return not_impl_err!(
"UnparseResult::WithinStatement is not supported for `extension_to_statement`"
);
}
UnparseResult::Original => {}
UnparseToStatementResult::Unmodified => {}
}
}
if let Some(statement) = statement {
Expand All @@ -166,13 +163,8 @@ impl Unparser<'_> {
) -> Result<()> {
for unparser in &self.extension_unparsers {
match unparser.unparse(node, self, query, select, relation)? {
UnparseResult::WithinStatement => return Ok(()),
UnparseResult::Original => {}
UnparseResult::Statement(_) => {
return not_impl_err!(
"UnparseResult::Statement is not supported for `extension_to_sql`"
);
}
UnparseWithinStatementResult::Modified => return Ok(()),
UnparseWithinStatementResult::Unmodified => {}
}
}
not_impl_err!("Unsupported extension node: {node:?}")
Expand Down
19 changes: 10 additions & 9 deletions datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ use datafusion_sql::unparser::ast::{
DerivedRelationBuilder, QueryBuilder, RelationBuilder, SelectBuilder,
};
use datafusion_sql::unparser::extension_unparser::{
UnparseResult, UserDefinedLogicalNodeUnparser,
UnparseToStatementResult, UnparseWithinStatementResult,
UserDefinedLogicalNodeUnparser,
};
use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect};
use sqlparser::parser::Parser;
Expand Down Expand Up @@ -1461,12 +1462,12 @@ impl UserDefinedLogicalNodeUnparser for MockStatementUnparser {
&self,
node: &dyn UserDefinedLogicalNode,
unparser: &Unparser,
) -> Result<UnparseResult> {
) -> Result<UnparseToStatementResult> {
if let Some(plan) = node.as_any().downcast_ref::<MockUserDefinedLogicalPlan>() {
let input = unparser.plan_to_sql(&plan.input)?;
Ok(UnparseResult::Statement(input))
Ok(UnparseToStatementResult::Modified(input))
} else {
Ok(UnparseResult::Original)
Ok(UnparseToStatementResult::Unmodified)
}
}
}
Expand All @@ -1481,15 +1482,15 @@ impl UserDefinedLogicalNodeUnparser for UnusedUnparser {
_query: &mut Option<&mut QueryBuilder>,
_select: &mut Option<&mut SelectBuilder>,
_relation: &mut Option<&mut RelationBuilder>,
) -> Result<UnparseResult> {
) -> Result<UnparseWithinStatementResult> {
panic!("This should not be called");
}

fn unparse_to_statement(
&self,
_node: &dyn UserDefinedLogicalNode,
_unparser: &Unparser,
) -> Result<UnparseResult> {
) -> Result<UnparseToStatementResult> {
panic!("This should not be called");
}
}
Expand Down Expand Up @@ -1537,10 +1538,10 @@ impl UserDefinedLogicalNodeUnparser for MockSqlUnparser {
_query: &mut Option<&mut QueryBuilder>,
_select: &mut Option<&mut SelectBuilder>,
relation: &mut Option<&mut RelationBuilder>,
) -> Result<UnparseResult> {
) -> Result<UnparseWithinStatementResult> {
if let Some(plan) = node.as_any().downcast_ref::<MockUserDefinedLogicalPlan>() {
let Statement::Query(input) = unparser.plan_to_sql(&plan.input)? else {
return Ok(UnparseResult::Original);
return Ok(UnparseWithinStatementResult::Unmodified);
};
let mut derived_builder = DerivedRelationBuilder::default();
derived_builder.subquery(input);
Expand All @@ -1549,7 +1550,7 @@ impl UserDefinedLogicalNodeUnparser for MockSqlUnparser {
rel.derived(derived_builder);
}
}
Ok(UnparseResult::WithinStatement)
Ok(UnparseWithinStatementResult::Modified)
}
}

Expand Down

0 comments on commit 7b6c37f

Please sign in to comment.