Skip to content

Commit

Permalink
[substrait] Add support for ExtensionTable
Browse files Browse the repository at this point in the history
  • Loading branch information
ccciudatu committed Dec 24, 2024
1 parent 6cfd1cf commit bfda3d2
Show file tree
Hide file tree
Showing 4 changed files with 218 additions and 38 deletions.
24 changes: 2 additions & 22 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ use datafusion_expr::{
expr_rewriter::FunctionRewrite,
logical_plan::{DdlStatement, Statement},
planner::ExprPlanner,
Expr, UserDefinedLogicalNode, WindowUDF,
Expr, WindowUDF,
};

// backwards compatibility
Expand Down Expand Up @@ -1679,27 +1679,7 @@ pub enum RegisterFunction {
#[derive(Debug)]
pub struct EmptySerializerRegistry;

impl SerializerRegistry for EmptySerializerRegistry {
fn serialize_logical_plan(
&self,
node: &dyn UserDefinedLogicalNode,
) -> Result<Vec<u8>> {
not_impl_err!(
"Serializing user defined logical plan node `{}` is not supported",
node.name()
)
}

fn deserialize_logical_plan(
&self,
name: &str,
_bytes: &[u8],
) -> Result<Arc<dyn UserDefinedLogicalNode>> {
not_impl_err!(
"Deserializing user defined logical plan node `{name}` is not supported"
)
}
}
impl SerializerRegistry for EmptySerializerRegistry {}

/// Describes which SQL statements can be run.
///
Expand Down
40 changes: 35 additions & 5 deletions datafusion/expr/src/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
use crate::expr_rewriter::FunctionRewrite;
use crate::planner::ExprPlanner;
use crate::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF};
use crate::{AggregateUDF, ScalarUDF, TableSource, UserDefinedLogicalNode, WindowUDF};
use datafusion_common::{not_impl_err, plan_datafusion_err, HashMap, Result};
use std::collections::HashSet;
use std::fmt::Debug;
Expand Down Expand Up @@ -123,22 +123,52 @@ pub trait FunctionRegistry {
}
}

/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode].
/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode]
/// and custom table providers for which the name alone is meaningless in the target
/// execution context, e.g. UDTFs, manually registered tables etc.
pub trait SerializerRegistry: Debug + Send + Sync {
/// Serialize this node to a byte array. This serialization should not include
/// input plans.
fn serialize_logical_plan(
&self,
node: &dyn UserDefinedLogicalNode,
) -> Result<Vec<u8>>;
) -> Result<Vec<u8>> {
not_impl_err!(
"Serializing user defined logical plan node `{}` is not supported",
node.name()
)
}

/// Deserialize user defined logical plan node ([UserDefinedLogicalNode]) from
/// bytes.
fn deserialize_logical_plan(
&self,
name: &str,
bytes: &[u8],
) -> Result<Arc<dyn UserDefinedLogicalNode>>;
_bytes: &[u8],
) -> Result<Arc<dyn UserDefinedLogicalNode>> {
not_impl_err!(
"Deserializing user defined logical plan node `{name}` is not supported"
)
}

/// Serialized table definition for UDTFs or manually registered table providers that can't be
/// marshaled by reference. Should return some benign error for regular tables that can be
/// found/restored by name in the destination execution context.
fn serialize_custom_table(&self, _table: &dyn TableSource) -> Result<Vec<u8>> {
not_impl_err!("No custom table support")
}

/// Deserialize the custom table with the given name.
/// Note: more often than not, the name can't be used as a discriminator if multiple different
/// `TableSource` and/or `TableProvider` implementations are expected (this is particularly true
/// for UDTFs in DataFusion, which are always registered under the same name: `tmp_table`).
fn deserialize_custom_table(
&self,
name: &str,
_bytes: &[u8],
) -> Result<Arc<dyn TableSource>> {
not_impl_err!("Deserializing custom table `{name}` is not supported")
}
}

/// A [`FunctionRegistry`] that uses in memory [`HashMap`]s
Expand Down
52 changes: 49 additions & 3 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use datafusion::logical_expr::expr::{Exists, InSubquery, Sort};

use datafusion::logical_expr::{
Aggregate, BinaryExpr, Case, Cast, EmptyRelation, Expr, ExprSchemable, Extension,
LogicalPlan, Operator, Projection, SortExpr, Subquery, TryCast, Values,
LogicalPlan, Operator, Projection, SortExpr, Subquery, TableScan, TryCast, Values,
};
use substrait::proto::aggregate_rel::Grouping;
use substrait::proto::expression as substrait_expression;
Expand Down Expand Up @@ -86,6 +86,7 @@ use substrait::proto::expression::{
SingularOrList, SwitchExpression, WindowFunction,
};
use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile;
use substrait::proto::read_rel::ExtensionTable;
use substrait::proto::rel_common::{Emit, EmitKind};
use substrait::proto::set_rel::SetOp;
use substrait::proto::{
Expand Down Expand Up @@ -438,6 +439,22 @@ pub trait SubstraitConsumer: Send + Sync + Sized {
user_defined_literal.type_reference
)
}

fn consume_extension_table(
&self,
extension_table: &ExtensionTable,
_schema: &DFSchema,
_projection: &Option<MaskExpression>,
) -> Result<LogicalPlan> {
if let Some(ext_detail) = extension_table.detail.as_ref() {
substrait_err!(
"Missing handler for extension table: {}",
&ext_detail.type_url
)
} else {
substrait_err!("Unexpected empty detail in ExtensionTable")
}
}
}

/// Convert Substrait Rel to DataFusion DataFrame
Expand Down Expand Up @@ -559,6 +576,32 @@ impl SubstraitConsumer for DefaultSubstraitConsumer<'_> {
let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?;
Ok(LogicalPlan::Extension(Extension { node: plan }))
}

fn consume_extension_table(
&self,
extension_table: &ExtensionTable,
schema: &DFSchema,
projection: &Option<MaskExpression>,
) -> Result<LogicalPlan> {
if let Some(ext_detail) = &extension_table.detail {
let source = self
.state
.serializer_registry()
.deserialize_custom_table(&ext_detail.type_url, &ext_detail.value)?;
let table_name = ext_detail
.type_url
.rsplit_once('/')
.map(|(_, name)| name)
.unwrap_or(&ext_detail.type_url);
let table_scan = TableScan::try_new(table_name, source, None, vec![], None)?;
let plan = LogicalPlan::TableScan(table_scan);
ensure_schema_compatibility(plan.schema(), schema.clone())?;
let schema = apply_masking(schema.clone(), projection)?;
apply_projection(plan, schema)
} else {
substrait_err!("Unexpected empty detail in ExtensionTable")
}
}
}

// Substrait PrecisionTimestampTz indicates that the timestamp is relative to UTC, which
Expand Down Expand Up @@ -1449,8 +1492,11 @@ pub async fn from_read_rel(
)
.await
}
_ => {
not_impl_err!("Unsupported ReadType: {:?}", read.read_type)
Some(ReadType::ExtensionTable(ext)) => {
consumer.consume_extension_table(ext, &substrait_schema, &read.projection)
}
None => {
substrait_err!("Unexpected empty read_type")
}
}
}
Expand Down
140 changes: 132 additions & 8 deletions datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ use substrait::proto::expression::literal::{
};
use substrait::proto::expression::subquery::InPredicate;
use substrait::proto::expression::window_function::BoundsType;
use substrait::proto::read_rel::VirtualTable;
use substrait::proto::read_rel::{ExtensionTable, VirtualTable};
use substrait::proto::rel_common::EmitKind;
use substrait::proto::rel_common::EmitKind::Emit;
use substrait::proto::{
Expand Down Expand Up @@ -211,6 +211,23 @@ pub fn to_substrait_rel(
let table_schema = scan.source.schema().to_dfschema_ref()?;
let base_schema = to_substrait_named_struct(&table_schema)?;

let table = if let Ok(bytes) = state
.serializer_registry()
.serialize_custom_table(scan.source.as_ref())
{
ReadType::ExtensionTable(ExtensionTable {
detail: Some(ProtoAny {
type_url: scan.table_name.to_string(),
value: bytes.into(),
}),
})
} else {
ReadType::NamedTable(NamedTable {
names: scan.table_name.to_vec(),
advanced_extension: None,
})
};

Ok(Box::new(Rel {
rel_type: Some(RelType::Read(Box::new(ReadRel {
common: None,
Expand All @@ -219,10 +236,7 @@ pub fn to_substrait_rel(
best_effort_filter: None,
projection,
advanced_extension: None,
read_type: Some(ReadType::NamedTable(NamedTable {
names: scan.table_name.to_vec(),
advanced_extension: None,
})),
read_type: Some(table),
}))),
}))
}
Expand Down Expand Up @@ -2238,17 +2252,21 @@ mod test {
use super::*;
use crate::logical_plan::consumer::{
from_substrait_extended_expr, from_substrait_literal_without_names,
from_substrait_named_struct, from_substrait_type_without_names,
DefaultSubstraitConsumer,
from_substrait_named_struct, from_substrait_plan,
from_substrait_type_without_names, DefaultSubstraitConsumer,
};
use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
use datafusion::arrow::array::{
GenericListArray, Int64Builder, MapBuilder, StringBuilder,
};
use datafusion::arrow::datatypes::{Field, Fields, Schema};
use datafusion::common::scalar::ScalarStructBuilder;
use datafusion::common::DFSchema;
use datafusion::common::{assert_contains, DFSchema};
use datafusion::datasource::empty::EmptyTable;
use datafusion::datasource::{DefaultTableSource, TableProvider};
use datafusion::execution::{SessionState, SessionStateBuilder};
use datafusion::logical_expr::registry::SerializerRegistry;
use datafusion::logical_expr::TableSource;
use datafusion::prelude::SessionContext;
use std::sync::OnceLock;

Expand Down Expand Up @@ -2585,4 +2603,110 @@ mod test {

assert!(matches!(err, Err(DataFusionError::SchemaError(_, _))));
}

#[tokio::test]
async fn round_trip_extension_table() {
const TABLE_NAME: &str = "custom_table";
const SERIALIZED: &[u8] = "table definition".as_bytes();

fn custom_table() -> Arc<dyn TableProvider> {
Arc::new(EmptyTable::new(Arc::new(Schema::new([
Arc::new(Field::new("id", DataType::Int32, false)),
Arc::new(Field::new("name", DataType::Utf8, false)),
]))))
}

#[derive(Debug)]
struct Registry;
impl SerializerRegistry for Registry {
fn serialize_custom_table(&self, table: &dyn TableSource) -> Result<Vec<u8>> {
if table.schema() == custom_table().schema() {
Ok(SERIALIZED.to_vec())
} else {
Err(DataFusionError::Internal("Not our table".into()))
}
}
fn deserialize_custom_table(
&self,
name: &str,
bytes: &[u8],
) -> Result<Arc<dyn TableSource>> {
if name == TABLE_NAME && bytes == SERIALIZED {
Ok(Arc::new(DefaultTableSource::new(custom_table())))
} else {
panic!("Unexpected extension table: {name}");
}
}
}

async fn round_trip_logical_plans(
local: &SessionContext,
remote: &SessionContext,
) -> Result<()> {
local.register_table(TABLE_NAME, custom_table())?;
remote.table_provider(TABLE_NAME).await.expect_err(
"The remote context is not supposed to know about custom_table",
);
let initial_plan = local
.sql(&format!("select id from {TABLE_NAME}"))
.await?
.logical_plan()
.clone();

// write substrait locally
let substrait = to_substrait_plan(&initial_plan, &local.state())?;

// read substrait remotely
// since we know there's no `custom_table` registered in the remote context, this will only succeed
// if our table got encoded as an ExtensionTable and is now decoded back to a table source.
let restored = from_substrait_plan(&remote.state(), &substrait).await?;
assert_contains!(
// confirm that the Substrait plan contains our custom_table as an ExtensionTable
serde_json::to_string(substrait.as_ref()).unwrap(),
format!(r#""extensionTable":{{"detail":{{"typeUrl":"{TABLE_NAME}","#)
);
remote // make sure the restored plan is fully working in the remote context
.execute_logical_plan(restored.clone())
.await?
.collect()
.await
.expect("Restored plan cannot be executed remotely");
assert_eq!(
// check that the restored plan is functionally equivalent (and almost identical) to the initial one
initial_plan.to_string(),
restored.to_string().replace(
// substrait will add an explicit full-schema projection if the original table had none
&format!("TableScan: {TABLE_NAME} projection=[id, name]"),
&format!("TableScan: {TABLE_NAME}"),
)
);
Ok(())
}

// take 1
let failed_attempt =
round_trip_logical_plans(&SessionContext::new(), &SessionContext::new())
.await
.expect_err(
"The round trip should fail in the absence of a SerializerRegistry",
);
assert_contains!(
failed_attempt.message(),
format!("No table named '{TABLE_NAME}'")
);

// take 2
fn proper_context() -> SessionContext {
SessionContext::new_with_state(
SessionStateBuilder::new()
// This will transport our custom_table as a Substrait ExtensionTable
.with_serializer_registry(Arc::new(Registry))
.build(),
)
}

round_trip_logical_plans(&proper_context(), &proper_context())
.await
.expect("Local plan could not be restored remotely");
}
}

0 comments on commit bfda3d2

Please sign in to comment.