diff --git a/datafusion/expr/src/registry.rs b/datafusion/expr/src/registry.rs index 4eb49710bcf85..f3ac906712029 100644 --- a/datafusion/expr/src/registry.rs +++ b/datafusion/expr/src/registry.rs @@ -19,8 +19,10 @@ use crate::expr_rewriter::FunctionRewrite; use crate::planner::ExprPlanner; -use crate::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF}; -use datafusion_common::{not_impl_err, plan_datafusion_err, HashMap, Result}; +use crate::{AggregateUDF, ScalarUDF, TableSource, UserDefinedLogicalNode, WindowUDF}; +use datafusion_common::{ + not_impl_err, plan_datafusion_err, DataFusionError, HashMap, Result, +}; use std::collections::HashSet; use std::fmt::Debug; use std::sync::Arc; @@ -123,22 +125,62 @@ pub trait FunctionRegistry { } } -/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode]. +/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode] +/// and custom table providers for which the name alone is meaningless in the target +/// execution context, e.g. UDTFs, manually registered tables etc. pub trait SerializerRegistry: Debug + Send + Sync { /// Serialize this node to a byte array. This serialization should not include /// input plans. fn serialize_logical_plan( &self, - node: &dyn UserDefinedLogicalNode, - ) -> Result>; + _node: &dyn UserDefinedLogicalNode, + ) -> Result> { + Err(DataFusionError::Plan( + "UserDefinedLogicalNode serialization not supported".into(), + )) + } /// Deserialize user defined logical plan node ([UserDefinedLogicalNode]) from /// bytes. fn deserialize_logical_plan( &self, - name: &str, - bytes: &[u8], - ) -> Result>; + _name: &str, + _bytes: &[u8], + ) -> Result> { + Err(DataFusionError::Plan( + "UserDefinedLogicalNode deserialization not supported".into(), + )) + } + + /// Binary representation for custom tables, to be converted to substrait extension tables. + /// Should only return success for table implementations that cannot be found by name + /// in the destination execution context, such as UDTFs or manually registered table providers. + fn serialize_custom_table(&self, _table: &dyn TableSource) -> Result> { + Err(DataFusionError::Plan( + "Custom table serialization not supported".into(), + )) + } + + /// Deserialize the custom table with the given name. + /// The name may not be useful as a discriminator if multiple UDTF/TableProvider + /// implementations are expected. This is particularly true for UDTFs in DataFusion, + /// which are always registered under the same name: `tmp_table`, so one should + /// use the binary payload to distinguish between multiple potential table types. + /// A potential future improvement would be to return a (name, bytes) tuple from + /// [SerializerRegistry::serialize_custom_table] to allow the implementors to assign + /// different names to different table provider implementations (e.g. in the case of proto, + /// by using the actual protobuf `type_url`). + /// But this would mean the table names in the restored plan may no longer match + /// the original ones. + fn deserialize_custom_table( + &self, + _name: &str, + _bytes: &[u8], + ) -> Result> { + Err(DataFusionError::Plan( + "Custom table deserialization not supported".into(), + )) + } } /// A [`FunctionRegistry`] that uses in memory [`HashMap`]s diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index a9e411e35ae88..963c8d9642671 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -31,7 +31,7 @@ use datafusion::logical_expr::expr::{Exists, InSubquery, Sort}; use datafusion::logical_expr::{ Aggregate, BinaryExpr, Case, EmptyRelation, Expr, ExprSchemable, LogicalPlan, - Operator, Projection, SortExpr, TryCast, Values, + Operator, Projection, SortExpr, TableScan, TryCast, Values, }; use substrait::proto::aggregate_rel::Grouping; use substrait::proto::expression::subquery::set_predicate::PredicateOp; @@ -994,8 +994,36 @@ pub async fn from_substrait_rel( ) .await } - _ => { - not_impl_err!("Unsupported ReadType: {:?}", &read.as_ref().read_type) + Some(ReadType::ExtensionTable(ext)) => { + if let Some(ext_detail) = &ext.detail { + let source = + state.serializer_registry().deserialize_custom_table( + &ext_detail.type_url, + &ext_detail.value, + )?; + let table_name = if let Some((_, name)) = + ext_detail.type_url.rsplit_once('/') + { + name + } else { + &ext_detail.type_url + }; + let plan = LogicalPlan::TableScan(TableScan::try_new( + table_name, + source, + None, + vec![], + None, + )?); + let schema = apply_masking(substrait_schema, &read.projection)?; + ensure_schema_compatability(plan.schema(), schema.clone())?; + apply_projection(plan, schema) + } else { + substrait_err!("Unexpected empty detail in ExtensionTable") + } + } + None => { + substrait_err!("Unexpected empty read_type") } } } diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 26d71c7fd3e24..0779956561172 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -64,7 +64,7 @@ use substrait::proto::expression::literal::{ }; use substrait::proto::expression::subquery::InPredicate; use substrait::proto::expression::window_function::BoundsType; -use substrait::proto::read_rel::VirtualTable; +use substrait::proto::read_rel::{ExtensionTable, VirtualTable}; use substrait::proto::rel_common::EmitKind; use substrait::proto::rel_common::EmitKind::Emit; use substrait::proto::{ @@ -211,6 +211,23 @@ pub fn to_substrait_rel( let table_schema = scan.source.schema().to_dfschema_ref()?; let base_schema = to_substrait_named_struct(&table_schema)?; + let table = if let Ok(bytes) = state + .serializer_registry() + .serialize_custom_table(scan.source.as_ref()) + { + ReadType::ExtensionTable(ExtensionTable { + detail: Some(ProtoAny { + type_url: scan.table_name.to_string(), + value: bytes.into(), + }), + }) + } else { + ReadType::NamedTable(NamedTable { + names: scan.table_name.to_vec(), + advanced_extension: None, + }) + }; + Ok(Box::new(Rel { rel_type: Some(RelType::Read(Box::new(ReadRel { common: None, @@ -219,10 +236,7 @@ pub fn to_substrait_rel( best_effort_filter: None, projection, advanced_extension: None, - read_type: Some(ReadType::NamedTable(NamedTable { - names: scan.table_name.to_vec(), - advanced_extension: None, - })), + read_type: Some(table), }))), })) } @@ -2202,7 +2216,8 @@ mod test { use super::*; use crate::logical_plan::consumer::{ from_substrait_extended_expr, from_substrait_literal_without_names, - from_substrait_named_struct, from_substrait_type_without_names, + from_substrait_named_struct, from_substrait_plan, + from_substrait_type_without_names, }; use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use datafusion::arrow::array::{ @@ -2210,8 +2225,13 @@ mod test { }; use datafusion::arrow::datatypes::{Field, Fields, Schema}; use datafusion::common::scalar::ScalarStructBuilder; - use datafusion::common::DFSchema; - use datafusion::execution::SessionStateBuilder; + use datafusion::common::{assert_contains, DFSchema}; + use datafusion::datasource::empty::EmptyTable; + use datafusion::datasource::{DefaultTableSource, TableProvider}; + use datafusion::execution::registry::SerializerRegistry; + use datafusion::execution::{SessionState, SessionStateBuilder}; + use datafusion::logical_expr::TableSource; + use datafusion::prelude::SessionContext; #[test] fn round_trip_literals() -> Result<()> { @@ -2518,4 +2538,106 @@ mod test { assert!(matches!(err, Err(DataFusionError::SchemaError(_, _)))); } + + #[tokio::test] + async fn round_trip_extension_table() -> Result<()> { + #[derive(Debug)] + struct Registry { + table: Arc, + } + impl SerializerRegistry for Registry { + fn serialize_custom_table( + &self, + _table: &dyn TableSource, + ) -> Result> { + Ok("expected payload".as_bytes().to_vec()) + } + fn deserialize_custom_table( + &self, + _name: &str, + _bytes: &[u8], + ) -> Result> { + Ok(Arc::new(DefaultTableSource::new(self.table.clone()))) + } + } + + async fn round_trip_logical_plans( + local: &SessionContext, + remote: &SessionContext, + table: Arc, + ) -> Result<(LogicalPlan, LogicalPlan)> { + local.register_table("custom_table", table)?; + let initial_plan = local + .sql("select id from custom_table") + .await? + .logical_plan() + .clone(); + + // write substrait locally + let substrait = to_substrait_plan(&initial_plan, &local.state())?; + + // read substrait remotely + // this will only succeed if our custom_table was encoded as an ExtensionTable, + // since there's no `custom_table` registered in the remote context. + let restored = from_substrait_plan(&remote.state(), &substrait).await?; + assert_contains!( + serde_json::to_string(substrait.as_ref()).unwrap(), + // value == base64("expected payload") + r#""extensionTable":{"detail":{"typeUrl":"custom_table","value":"ZXhwZWN0ZWQgcGF5bG9hZA=="}}"# + ); + Ok((initial_plan, restored)) + } + + let empty = Arc::new(EmptyTable::new(Arc::new(Schema::new([ + Arc::new(Field::new("id", DataType::Int32, false)), + Arc::new(Field::new("name", DataType::Utf8, false)), + ])))); + + let first_attempt = round_trip_logical_plans( + &SessionContext::new(), + &SessionContext::new(), + empty.clone(), + ) + .await; + assert_eq!( + first_attempt.unwrap_err().to_string(), + "Error during planning: No table named 'custom_table'" + ); + fn proper_state(table: Arc) -> SessionState { + SessionStateBuilder::new() + .with_default_features() + .with_serializer_registry(Arc::new(Registry { table })) + .build() + } + let local = SessionContext::new_with_state(proper_state(empty.clone())); + let remote = SessionContext::new_with_state(proper_state(empty.clone())); + + let (initial_plan, restored) = round_trip_logical_plans(&local, &remote, empty) + .await + .expect("Should restore the substrait plan as datafusion logical plan"); + + assert_eq!( + initial_plan.to_string(), + restored + .to_string() + // substrait will add an explicit projection with the full schema + .replace( + "TableScan: custom_table projection=[id, name]", + "TableScan: custom_table" + ) + ); + assert_eq!( + local + .execute_logical_plan(initial_plan) + .await? + .collect() + .await?, + remote + .execute_logical_plan(restored) + .await? + .collect() + .await?, + ); + Ok(()) + } }