diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index f20475df150b..c81e0afa1827 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -63,7 +63,7 @@ use datafusion_expr::{ expr_rewriter::FunctionRewrite, logical_plan::{DdlStatement, Statement}, planner::ExprPlanner, - Expr, UserDefinedLogicalNode, WindowUDF, + Expr, WindowUDF, }; // backwards compatibility @@ -1682,27 +1682,7 @@ pub enum RegisterFunction { #[derive(Debug)] pub struct EmptySerializerRegistry; -impl SerializerRegistry for EmptySerializerRegistry { - fn serialize_logical_plan( - &self, - node: &dyn UserDefinedLogicalNode, - ) -> Result> { - not_impl_err!( - "Serializing user defined logical plan node `{}` is not supported", - node.name() - ) - } - - fn deserialize_logical_plan( - &self, - name: &str, - _bytes: &[u8], - ) -> Result> { - not_impl_err!( - "Deserializing user defined logical plan node `{name}` is not supported" - ) - } -} +impl SerializerRegistry for EmptySerializerRegistry {} /// Describes which SQL statements can be run. /// diff --git a/datafusion/expr/src/registry.rs b/datafusion/expr/src/registry.rs index 4eb49710bcf8..a2f5a45e7b9b 100644 --- a/datafusion/expr/src/registry.rs +++ b/datafusion/expr/src/registry.rs @@ -19,7 +19,7 @@ use crate::expr_rewriter::FunctionRewrite; use crate::planner::ExprPlanner; -use crate::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF}; +use crate::{AggregateUDF, ScalarUDF, TableSource, UserDefinedLogicalNode, WindowUDF}; use datafusion_common::{not_impl_err, plan_datafusion_err, HashMap, Result}; use std::collections::HashSet; use std::fmt::Debug; @@ -123,24 +123,58 @@ pub trait FunctionRegistry { } } -/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode]. +/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode] +/// and custom table providers for which the name alone is meaningless in the target +/// execution context, e.g. UDTFs, manually registered tables etc. pub trait SerializerRegistry: Debug + Send + Sync { /// Serialize this node to a byte array. This serialization should not include /// input plans. fn serialize_logical_plan( &self, node: &dyn UserDefinedLogicalNode, - ) -> Result>; + ) -> Result { + not_impl_err!( + "Serializing user defined logical plan node `{}` is not supported", + node.name() + ) + } /// Deserialize user defined logical plan node ([UserDefinedLogicalNode]) from /// bytes. fn deserialize_logical_plan( &self, name: &str, - bytes: &[u8], - ) -> Result>; + _bytes: &[u8], + ) -> Result> { + not_impl_err!( + "Deserializing user defined logical plan node `{name}` is not supported" + ) + } + + /// Serialized table definition for UDTFs or some other table provider implementation that + /// can't be marshaled by reference. + fn serialize_custom_table( + &self, + _table: &dyn TableSource, + ) -> Result> { + Ok(None) + } + + /// Deserialize a custom table. + fn deserialize_custom_table( + &self, + name: &str, + _bytes: &[u8], + ) -> Result> { + not_impl_err!("Deserializing custom table `{name}` is not supported") + } } +/// A sequence of bytes with a string qualifier. Meant to encapsulate serialized extensions +/// that need to carry their type, e.g. the `type_url` for protobuf messages. +#[derive(Debug, Clone)] +pub struct NamedBytes(pub String, pub Vec); + /// A [`FunctionRegistry`] that uses in memory [`HashMap`]s #[derive(Default, Debug)] pub struct MemoryFunctionRegistry { diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 9623f12c88dd..6a7857bc0a1e 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -30,7 +30,7 @@ use datafusion::logical_expr::expr::{Exists, InSubquery, Sort}; use datafusion::logical_expr::{ Aggregate, BinaryExpr, Case, Cast, EmptyRelation, Expr, ExprSchemable, Extension, - LogicalPlan, Operator, Projection, SortExpr, Subquery, TryCast, Values, + LogicalPlan, Operator, Projection, SortExpr, Subquery, TableSource, TryCast, Values, }; use substrait::proto::aggregate_rel::Grouping; use substrait::proto::expression as substrait_expression; @@ -86,6 +86,7 @@ use substrait::proto::expression::{ SingularOrList, SwitchExpression, WindowFunction, }; use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile; +use substrait::proto::read_rel::ExtensionTable; use substrait::proto::rel_common::{Emit, EmitKind}; use substrait::proto::set_rel::SetOp; use substrait::proto::{ @@ -457,6 +458,20 @@ pub trait SubstraitConsumer: Send + Sync + Sized { user_defined_literal.type_reference ) } + + fn consume_extension_table( + &self, + extension_table: &ExtensionTable, + ) -> Result> { + if let Some(ext_detail) = extension_table.detail.as_ref() { + substrait_err!( + "Missing handler for extension table: {}", + &ext_detail.type_url + ) + } else { + substrait_err!("Unexpected empty detail in ExtensionTable") + } + } } /// Convert Substrait Rel to DataFusion DataFrame @@ -578,6 +593,19 @@ impl SubstraitConsumer for DefaultSubstraitConsumer<'_> { let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?; Ok(LogicalPlan::Extension(Extension { node: plan })) } + + fn consume_extension_table( + &self, + extension_table: &ExtensionTable, + ) -> Result> { + if let Some(ext_detail) = &extension_table.detail { + self.state + .serializer_registry() + .deserialize_custom_table(&ext_detail.type_url, &ext_detail.value) + } else { + substrait_err!("Unexpected empty detail in ExtensionTable") + } + } } // Substrait PrecisionTimestampTz indicates that the timestamp is relative to UTC, which @@ -1323,26 +1351,14 @@ pub async fn from_read_rel( read: &ReadRel, ) -> Result { async fn read_with_schema( - consumer: &impl SubstraitConsumer, table_ref: TableReference, + table_source: Arc, schema: DFSchema, projection: &Option, ) -> Result { let schema = schema.replace_qualifier(table_ref.clone()); - let plan = { - let provider = match consumer.resolve_table_ref(&table_ref).await? { - Some(ref provider) => Arc::clone(provider), - _ => return plan_err!("No table named '{table_ref}'"), - }; - - LogicalPlanBuilder::scan( - table_ref, - provider_as_source(Arc::clone(&provider)), - None, - )? - .build()? - }; + let plan = { LogicalPlanBuilder::scan(table_ref, table_source, None)?.build()? }; ensure_schema_compatibility(plan.schema(), schema.clone())?; @@ -1351,6 +1367,17 @@ pub async fn from_read_rel( apply_projection(plan, schema) } + async fn table_source( + consumer: &impl SubstraitConsumer, + table_ref: &TableReference, + ) -> Result> { + if let Some(provider) = consumer.resolve_table_ref(table_ref).await? { + Ok(provider_as_source(provider)) + } else { + plan_err!("No table named '{table_ref}'") + } + } + let named_struct = read.base_schema.as_ref().ok_or_else(|| { substrait_datafusion_err!("No base schema provided for Read Relation") })?; @@ -1376,10 +1403,10 @@ pub async fn from_read_rel( table: nt.names[2].clone().into(), }, }; - + let table_source = table_source(consumer, &table_reference).await?; read_with_schema( - consumer, table_reference, + table_source, substrait_schema, &read.projection, ) @@ -1458,17 +1485,38 @@ pub async fn from_read_rel( let name = filename.unwrap(); // directly use unwrap here since we could determine it is a valid one let table_reference = TableReference::Bare { table: name.into() }; + let table_source = table_source(consumer, &table_reference).await?; read_with_schema( - consumer, table_reference, + table_source, + substrait_schema, + &read.projection, + ) + .await + } + Some(ReadType::ExtensionTable(ext)) => { + // look for the original table name under `rel.common.hint.alias` + // in case the producer was kind enough to put it there. + let name_hint = read + .common + .as_ref() + .and_then(|rel_common| rel_common.hint.as_ref()) + .map(|hint| hint.alias.as_str().trim()) + .filter(|alias| !alias.is_empty()); + // if no name hint was provided, use the name that datafusion + // sets for UDTFs + let table_name = name_hint.unwrap_or("tmp_table"); + read_with_schema( + TableReference::from(table_name), + consumer.consume_extension_table(ext)?, substrait_schema, &read.projection, ) .await } - _ => { - not_impl_err!("Unsupported ReadType: {:?}", read.read_type) + None => { + substrait_err!("Unexpected empty read_type") } } } @@ -1871,7 +1919,7 @@ pub async fn from_substrait_sorts( }, None => not_impl_err!("Sort without sort kind is invalid"), }; - let (asc, nulls_first) = asc_nullfirst.unwrap(); + let (asc, nulls_first) = asc_nullfirst?; sorts.push(Sort { expr, asc, diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index e501ddf5c698..ac0b101f4225 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -25,7 +25,7 @@ use datafusion::arrow::datatypes::{Field, IntervalUnit}; use datafusion::logical_expr::{ Aggregate, Distinct, EmptyRelation, Extension, Filter, Join, Like, Limit, Partitioning, Projection, Repartition, Sort, SortExpr, SubqueryAlias, TableScan, - TryCast, Union, Values, Window, WindowFrameUnits, + TableSource, TryCast, Union, Values, Window, WindowFrameUnits, }; use datafusion::{ arrow::datatypes::{DataType, TimeUnit}, @@ -54,6 +54,7 @@ use datafusion::execution::SessionState; use datafusion::logical_expr::expr::{ Alias, BinaryExpr, Case, Cast, GroupingSet, InList, InSubquery, WindowFunction, }; +use datafusion::logical_expr::registry::NamedBytes; use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; use datafusion::prelude::Expr; use pbjson_types::Any as ProtoAny; @@ -69,9 +70,9 @@ use substrait::proto::expression::literal::{ use substrait::proto::expression::subquery::InPredicate; use substrait::proto::expression::window_function::BoundsType; use substrait::proto::expression::ScalarFunction; -use substrait::proto::read_rel::VirtualTable; -use substrait::proto::rel_common::EmitKind; +use substrait::proto::read_rel::{ExtensionTable, VirtualTable}; use substrait::proto::rel_common::EmitKind::Emit; +use substrait::proto::rel_common::{EmitKind, Hint}; use substrait::proto::{ fetch_rel, rel_common, ExchangeRel, ExpressionReference, ExtendedExpression, RelCommon, @@ -366,6 +367,13 @@ pub trait SubstraitProducer: Send + Sync + Sized { ) -> Result { from_in_subquery(self, in_subquery, schema) } + + fn handle_custom_table( + &mut self, + _table: &dyn TableSource, + ) -> Result> { + not_impl_err!("Not implemented") + } } struct DefaultSubstraitProducer<'a> { @@ -392,12 +400,12 @@ impl SubstraitProducer for DefaultSubstraitProducer<'_> { } fn handle_extension(&mut self, plan: &Extension) -> Result> { - let extension_bytes = self + let NamedBytes(type_url, bytes) = self .serializer_registry .serialize_logical_plan(plan.node.as_ref())?; let detail = ProtoAny { - type_url: plan.node.name().to_string(), - value: extension_bytes.into(), + type_url, + value: bytes.into(), }; let mut inputs_rel = plan .node @@ -425,6 +433,24 @@ impl SubstraitProducer for DefaultSubstraitProducer<'_> { rel_type: Some(rel_type), })) } + + fn handle_custom_table( + &mut self, + table: &dyn TableSource, + ) -> Result> { + if let Some(NamedBytes(type_url, bytes)) = + self.serializer_registry.serialize_custom_table(table)? + { + Ok(Some(ExtensionTable { + detail: Some(ProtoAny { + type_url, + value: bytes.into(), + }), + })) + } else { + Ok(None) + } + } } /// Convert DataFusion LogicalPlan to Substrait Plan @@ -539,7 +565,7 @@ pub fn to_substrait_rel( } pub fn from_table_scan( - _producer: &mut impl SubstraitProducer, + producer: &mut impl SubstraitProducer, scan: &TableScan, ) -> Result> { let projection = scan.projection.as_ref().map(|p| { @@ -559,18 +585,38 @@ pub fn from_table_scan( let table_schema = scan.source.schema().to_dfschema_ref()?; let base_schema = to_substrait_named_struct(&table_schema)?; + let (table, common) = + if let Ok(Some(ext_table)) = producer.handle_custom_table(scan.source.as_ref()) { + ( + ReadType::ExtensionTable(ext_table), + Some(RelCommon { + hint: Some(Hint { + // store the original table name as rel.common.hint.alias + alias: scan.table_name.to_string(), + ..Default::default() + }), + ..Default::default() + }), + ) + } else { + ( + ReadType::NamedTable(NamedTable { + names: scan.table_name.to_vec(), + advanced_extension: None, + }), + None, + ) + }; + Ok(Box::new(Rel { rel_type: Some(RelType::Read(Box::new(ReadRel { - common: None, + common, base_schema: Some(base_schema), filter: None, best_effort_filter: None, projection, advanced_extension: None, - read_type: Some(ReadType::NamedTable(NamedTable { - names: scan.table_name.to_vec(), - advanced_extension: None, - })), + read_type: Some(table), }))), })) } @@ -1693,7 +1739,7 @@ pub fn from_in_subquery( subquery_type: Some( substrait::proto::expression::subquery::SubqueryType::InPredicate( Box::new(InPredicate { - needles: (vec![substrait_expr]), + needles: vec![substrait_expr], haystack: Some(subquery_plan), }), ), @@ -2532,8 +2578,8 @@ mod test { use super::*; use crate::logical_plan::consumer::{ from_substrait_extended_expr, from_substrait_literal_without_names, - from_substrait_named_struct, from_substrait_type_without_names, - DefaultSubstraitConsumer, + from_substrait_named_struct, from_substrait_plan, + from_substrait_type_without_names, DefaultSubstraitConsumer, }; use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use datafusion::arrow::array::{ @@ -2541,8 +2587,12 @@ mod test { }; use datafusion::arrow::datatypes::{Field, Fields, Schema}; use datafusion::common::scalar::ScalarStructBuilder; - use datafusion::common::DFSchema; + use datafusion::common::{assert_contains, DFSchema}; + use datafusion::datasource::empty::EmptyTable; + use datafusion::datasource::{DefaultTableSource, TableProvider}; use datafusion::execution::{SessionState, SessionStateBuilder}; + use datafusion::logical_expr::registry::SerializerRegistry; + use datafusion::logical_expr::TableSource; use datafusion::prelude::SessionContext; use std::sync::OnceLock; @@ -2879,4 +2929,114 @@ mod test { assert!(matches!(err, Err(DataFusionError::SchemaError(_, _)))); } + + #[tokio::test] + async fn round_trip_extension_table() { + const TABLE_NAME: &str = "custom_table"; + const TYPE_URL: &str = "/substrait.test.CustomTable"; + const SERIALIZED: &[u8] = "table definition".as_bytes(); + + fn custom_table() -> Arc { + Arc::new(EmptyTable::new(Arc::new(Schema::new([ + Arc::new(Field::new("id", DataType::Int32, false)), + Arc::new(Field::new("name", DataType::Utf8, false)), + ])))) + } + + #[derive(Debug)] + struct Registry; + impl SerializerRegistry for Registry { + fn serialize_custom_table( + &self, + table: &dyn TableSource, + ) -> Result> { + if table.schema() == custom_table().schema() { + Ok(Some(NamedBytes(TYPE_URL.to_string(), SERIALIZED.to_vec()))) + } else { + Err(DataFusionError::Internal("Not our table".into())) + } + } + fn deserialize_custom_table( + &self, + name: &str, + bytes: &[u8], + ) -> Result> { + if name == TYPE_URL && bytes == SERIALIZED { + Ok(Arc::new(DefaultTableSource::new(custom_table()))) + } else { + panic!("Unexpected extension table: {name}"); + } + } + } + + async fn round_trip_logical_plans( + local: &SessionContext, + remote: &SessionContext, + ) -> Result<()> { + local.register_table(TABLE_NAME, custom_table())?; + remote.table_provider(TABLE_NAME).await.expect_err( + "The remote context is not supposed to know about custom_table", + ); + let initial_plan = local + .sql(&format!("select id from {TABLE_NAME}")) + .await? + .logical_plan() + .clone(); + + // write substrait locally + let substrait = to_substrait_plan(&initial_plan, &local.state())?; + + // read substrait remotely + // since we know there's no `custom_table` registered in the remote context, this will only succeed + // if our table got encoded as an ExtensionTable and is now decoded back to a table source. + let restored = from_substrait_plan(&remote.state(), &substrait).await?; + assert_contains!( + // confirm that the Substrait plan contains our custom_table as an ExtensionTable + serde_json::to_string(substrait.as_ref()).unwrap(), + format!(r#""extensionTable":{{"detail":{{"typeUrl":"{TYPE_URL}","#) + ); + remote // make sure the restored plan is fully working in the remote context + .execute_logical_plan(restored.clone()) + .await? + .collect() + .await + .expect("Restored plan cannot be executed remotely"); + assert_eq!( + // check that the restored plan is functionally equivalent (and almost identical) to the initial one + initial_plan.to_string(), + restored.to_string().replace( + // substrait will add an explicit full-schema projection if the original table had none + &format!("TableScan: {TABLE_NAME} projection=[id, name]"), + &format!("TableScan: {TABLE_NAME}"), + ) + ); + Ok(()) + } + + // take 1 + let failed_attempt = + round_trip_logical_plans(&SessionContext::new(), &SessionContext::new()) + .await + .expect_err( + "The round trip should fail in the absence of a SerializerRegistry", + ); + assert_contains!( + failed_attempt.message(), + format!("No table named '{TABLE_NAME}'") + ); + + // take 2 + fn proper_context() -> SessionContext { + SessionContext::new_with_state( + SessionStateBuilder::new() + // This will transport our custom_table as a Substrait ExtensionTable + .with_serializer_registry(Arc::new(Registry)) + .build(), + ) + } + + round_trip_logical_plans(&proper_context(), &proper_context()) + .await + .expect("Local plan could not be restored remotely"); + } } diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 7045729493b1..0a9c8e525745 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -31,6 +31,7 @@ use datafusion::error::Result; use datafusion::execution::registry::SerializerRegistry; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::execution::session_state::SessionStateBuilder; +use datafusion::logical_expr::registry::NamedBytes; use datafusion::logical_expr::{ Extension, LogicalPlan, PartitionEvaluator, Repartition, UserDefinedLogicalNode, Values, Volatility, @@ -50,13 +51,13 @@ impl SerializerRegistry for MockSerializerRegistry { fn serialize_logical_plan( &self, node: &dyn UserDefinedLogicalNode, - ) -> Result> { + ) -> Result { if node.name() == "MockUserDefinedLogicalPlan" { let node = node .as_any() .downcast_ref::() .unwrap(); - node.serialize() + Ok(NamedBytes(node.name().to_string(), node.serialize()?)) } else { unreachable!() }