Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[substrait] Add support for ExtensionTable #13772

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 2 additions & 22 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ use datafusion_expr::{
expr_rewriter::FunctionRewrite,
logical_plan::{DdlStatement, Statement},
planner::ExprPlanner,
Expr, UserDefinedLogicalNode, WindowUDF,
Expr, WindowUDF,
};

// backwards compatibility
Expand Down Expand Up @@ -1679,27 +1679,7 @@ pub enum RegisterFunction {
#[derive(Debug)]
pub struct EmptySerializerRegistry;

/// [`SerializerRegistry`] implementation that rejects every request: both methods
/// return a "not implemented" error naming the offending node.
impl SerializerRegistry for EmptySerializerRegistry {
    /// Always fails, reporting the name of the node that could not be serialized.
    fn serialize_logical_plan(
        &self,
        node: &dyn UserDefinedLogicalNode,
    ) -> Result<Vec<u8>> {
        let node_name = node.name();
        not_impl_err!(
            "Serializing user defined logical plan node `{}` is not supported",
            node_name
        )
    }

    /// Always fails, reporting the requested node `name`; the payload is ignored.
    fn deserialize_logical_plan(
        &self,
        name: &str,
        _bytes: &[u8],
    ) -> Result<Arc<dyn UserDefinedLogicalNode>> {
        not_impl_err!(
            "Deserializing user defined logical plan node `{name}` is not supported"
        )
    }
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The SerializerRegistry trait now has two more methods for handling tables (with default implementations for backwards compatibility), so it makes sense for the existing methods to have default implementations as well.
This will allow implementors to conveniently implement the trait for user-defined logical nodes only or for tables only.
Since the implementations here are perfect as trait defaults, this PR just moves them into the trait itself.

// No overrides here: every `SerializerRegistry` method used through this type falls
// back to the default implementation provided by the trait itself.
impl SerializerRegistry for EmptySerializerRegistry {}

/// Describes which SQL statements can be run.
///
Expand Down
40 changes: 35 additions & 5 deletions datafusion/expr/src/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

use crate::expr_rewriter::FunctionRewrite;
use crate::planner::ExprPlanner;
use crate::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF};
use crate::{AggregateUDF, ScalarUDF, TableSource, UserDefinedLogicalNode, WindowUDF};
use datafusion_common::{not_impl_err, plan_datafusion_err, HashMap, Result};
use std::collections::HashSet;
use std::fmt::Debug;
Expand Down Expand Up @@ -123,22 +123,52 @@ pub trait FunctionRegistry {
}
}

/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode].
/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode]
/// and custom table providers for which the name alone is meaningless in the target
/// execution context, e.g. UDTFs, manually registered tables etc.
pub trait SerializerRegistry: Debug + Send + Sync {
/// Serialize this node to a byte array. This serialization should not include
/// input plans.
///
/// The default implementation returns a "not implemented" error that names the node.
fn serialize_logical_plan(
&self,
node: &dyn UserDefinedLogicalNode,
) -> Result<Vec<u8>>;
) -> Result<Vec<u8>> {
not_impl_err!(
"Serializing user defined logical plan node `{}` is not supported",
node.name()
)
}

/// Deserialize user defined logical plan node ([UserDefinedLogicalNode]) from
/// bytes.
///
/// The default implementation ignores the bytes and returns a "not implemented"
/// error that mentions `name`.
fn deserialize_logical_plan(
&self,
name: &str,
bytes: &[u8],
) -> Result<Arc<dyn UserDefinedLogicalNode>>;
_bytes: &[u8],
) -> Result<Arc<dyn UserDefinedLogicalNode>> {
not_impl_err!(
"Deserializing user defined logical plan node `{name}` is not supported"
)
}

/// Serialize the table definition for UDTFs or manually registered table providers that
/// can't be marshaled by reference. Should return some benign error for regular tables
/// that can be found/restored by name in the destination execution context.
///
/// The default implementation returns a "not implemented" error for every table.
fn serialize_custom_table(&self, _table: &dyn TableSource) -> Result<Vec<u8>> {
not_impl_err!("No custom table support")
}

/// Deserialize the custom table with the given name.
///
/// Note: more often than not, the name can't be used as a discriminator if multiple
/// different `TableSource` and/or `TableProvider` implementations are expected (this is
/// particularly true for UDTFs in DataFusion, which are always registered under the same
/// name: `tmp_table`).
///
/// The default implementation ignores the bytes and returns a "not implemented" error
/// that mentions `name`.
fn deserialize_custom_table(
&self,
name: &str,
_bytes: &[u8],
) -> Result<Arc<dyn TableSource>> {
not_impl_err!("Deserializing custom table `{name}` is not supported")
}
}

/// A [`FunctionRegistry`] that uses in memory [`HashMap`]s
Expand Down
52 changes: 49 additions & 3 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use datafusion::logical_expr::expr::{Exists, InSubquery, Sort};

use datafusion::logical_expr::{
Aggregate, BinaryExpr, Case, Cast, EmptyRelation, Expr, ExprSchemable, Extension,
LogicalPlan, Operator, Projection, SortExpr, Subquery, TryCast, Values,
LogicalPlan, Operator, Projection, SortExpr, Subquery, TableScan, TryCast, Values,
};
use substrait::proto::aggregate_rel::Grouping;
use substrait::proto::expression as substrait_expression;
Expand Down Expand Up @@ -86,6 +86,7 @@ use substrait::proto::expression::{
SingularOrList, SwitchExpression, WindowFunction,
};
use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile;
use substrait::proto::read_rel::ExtensionTable;
use substrait::proto::rel_common::{Emit, EmitKind};
use substrait::proto::set_rel::SetOp;
use substrait::proto::{
Expand Down Expand Up @@ -438,6 +439,22 @@ pub trait SubstraitConsumer: Send + Sync + Sized {
user_defined_literal.type_reference
)
}

fn consume_extension_table(
&self,
extension_table: &ExtensionTable,
_schema: &DFSchema,
_projection: &Option<MaskExpression>,
) -> Result<LogicalPlan> {
Copy link
Contributor

@vbarua vbarua Dec 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I chose to pass the schema and projection to this method (instead of keeping the TableScan postprocessing in from_read_rel) to allow custom implementations to use that information for fully restoring their custom tables if needed.

I like where your head is at with this, but I almost want to go further. You already called out:

leverage the available ReadRel information that is currently unused (e.g. filtering, advanced extensions etc.)

Maybe the interface for this should just be:

    fn consume_extension_table(
        &self,
        read_rel: &ReadRel,
        extension_table: &ExtensionTable) -> Result<LogicalPlan>

which will be future proofed for if fields are ever added to the ReadRel, and also provides access to common fields on the ReadRel.

We could even go further and add

    fn consume_named_table(
        &self,
        read_rel: &ReadRel,
        named_table: &NamedTable) -> Result<LogicalPlan>

    fn consume_virtual_table(
        &self,
        read_rel: &ReadRel,
        virtual_table: &VirtualTable) -> Result<LogicalPlan>

to make it easier to customize behaviour for specific read_types. This last idea might be better as its own PR, as we would need to factor out some of the code in from_read_rel into functions to be re-used across the new helpers.

Copy link
Contributor Author

@ccciudatu ccciudatu Dec 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can see your point, especially as I've been toying with the same idea myself.
However, I found the intrinsic redundancy a bit error prone: at least in theory, this interface allows the ReadRel to contain a table other than the one passed in as the second argument.
Eliminating this redundancy would end up with the exact same signature as consume_read, which renders the new helper(s) superfluous.

Copy link
Contributor

@Blizzara Blizzara Dec 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this is maybe slightly against the stated goal of "to allow custom implementations to use that information for fully restoring their custom tables if needed", but I'm not sure why someone would need custom impl for that behavior based on the read type?

How about something like:

from_read_rel(consumer: &SubstraitConsumer, read: &ReadRel, ..) -> .. {
   let plan = match read.type {
     Some(NamedTable(nt)) => consumer.consume_named_table(nt)
     Some(VirtualTable(vt)) =>  consumer.consume_virtual_table(vt),
     Some(ExtensionTable(et)) => consumer.consume_extension_table(et),
    ...
  }
  ensure_schema_compatibility(plan.schema(), schema.clone())?;
  let schema = apply_masking(schema, projection)?;
  apply_projection(plan, schema)
}

That way the ReadRel handling doesn't need to happen in multiple places, its projection and schema are by default handled the same way for all relations (which I'd think they should?), but if a user wants they can easily override the whole from_read_rel (or more specifically, SubstraitConsumer::consume_read) and compose their desired result.

(Actually, dunno if there's reason to have consumer.consume_named_table() and .consume_virtual_table(), given probably nothing else than consume_read calls those, their default impl should be good enough and if not they can always be overridden by implementing consume_read. But having consumer.consume_extension_table makes sense as an easy way to specify that behavior.)

if let Some(ext_detail) = extension_table.detail.as_ref() {
substrait_err!(
"Missing handler for extension table: {}",
&ext_detail.type_url
)
} else {
substrait_err!("Unexpected empty detail in ExtensionTable")
}
}
}

/// Convert Substrait Rel to DataFusion DataFrame
Expand Down Expand Up @@ -559,6 +576,32 @@ impl SubstraitConsumer for DefaultSubstraitConsumer<'_> {
let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?;
Ok(LogicalPlan::Extension(Extension { node: plan }))
}

/// Rebuilds a table scan from a Substrait `ExtensionTable`.
///
/// The `detail` payload is passed to the session's serializer registry
/// (`deserialize_custom_table`) to obtain the table source. The resulting scan is
/// checked against `schema` and then narrowed according to `projection`.
///
/// Returns a Substrait error when the extension table carries no `detail`.
fn consume_extension_table(
    &self,
    extension_table: &ExtensionTable,
    schema: &DFSchema,
    projection: &Option<MaskExpression>,
) -> Result<LogicalPlan> {
    let Some(detail) = &extension_table.detail else {
        return substrait_err!("Unexpected empty detail in ExtensionTable");
    };

    // The registry receives the complete type URL together with the raw payload.
    let registry = self.state.serializer_registry();
    let source = registry.deserialize_custom_table(&detail.type_url, &detail.value)?;

    // The scan is named after the last `/`-separated segment of the type URL,
    // or after the whole URL when it contains no `/`.
    let table_name: &str = match detail.type_url.rsplit_once('/') {
        Some((_, last_segment)) => last_segment,
        None => &detail.type_url,
    };

    let scan = TableScan::try_new(table_name, source, None, vec![], None)?;
    let plan = LogicalPlan::TableScan(scan);
    ensure_schema_compatibility(plan.schema(), schema.clone())?;
    let masked_schema = apply_masking(schema.clone(), projection)?;
    apply_projection(plan, masked_schema)
}
}

// Substrait PrecisionTimestampTz indicates that the timestamp is relative to UTC, which
Expand Down Expand Up @@ -1449,8 +1492,11 @@ pub async fn from_read_rel(
)
.await
}
_ => {
not_impl_err!("Unsupported ReadType: {:?}", read.read_type)
Some(ReadType::ExtensionTable(ext)) => {
consumer.consume_extension_table(ext, &substrait_schema, &read.projection)
}
None => {
substrait_err!("Unexpected empty read_type")
}
}
}
Expand Down
140 changes: 132 additions & 8 deletions datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ use substrait::proto::expression::literal::{
};
use substrait::proto::expression::subquery::InPredicate;
use substrait::proto::expression::window_function::BoundsType;
use substrait::proto::read_rel::VirtualTable;
use substrait::proto::read_rel::{ExtensionTable, VirtualTable};
use substrait::proto::rel_common::EmitKind;
use substrait::proto::rel_common::EmitKind::Emit;
use substrait::proto::{
Expand Down Expand Up @@ -211,6 +211,23 @@ pub fn to_substrait_rel(
let table_schema = scan.source.schema().to_dfschema_ref()?;
let base_schema = to_substrait_named_struct(&table_schema)?;

let table = if let Ok(bytes) = state
.serializer_registry()
.serialize_custom_table(scan.source.as_ref())
{
ReadType::ExtensionTable(ExtensionTable {
detail: Some(ProtoAny {
type_url: scan.table_name.to_string(),
value: bytes.into(),
}),
})
} else {
ReadType::NamedTable(NamedTable {
names: scan.table_name.to_vec(),
advanced_extension: None,
})
};

Ok(Box::new(Rel {
rel_type: Some(RelType::Read(Box::new(ReadRel {
common: None,
Expand All @@ -219,10 +236,7 @@ pub fn to_substrait_rel(
best_effort_filter: None,
projection,
advanced_extension: None,
read_type: Some(ReadType::NamedTable(NamedTable {
names: scan.table_name.to_vec(),
advanced_extension: None,
})),
read_type: Some(table),
}))),
}))
}
Expand Down Expand Up @@ -2238,17 +2252,21 @@ mod test {
use super::*;
use crate::logical_plan::consumer::{
from_substrait_extended_expr, from_substrait_literal_without_names,
from_substrait_named_struct, from_substrait_type_without_names,
DefaultSubstraitConsumer,
from_substrait_named_struct, from_substrait_plan,
from_substrait_type_without_names, DefaultSubstraitConsumer,
};
use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
use datafusion::arrow::array::{
GenericListArray, Int64Builder, MapBuilder, StringBuilder,
};
use datafusion::arrow::datatypes::{Field, Fields, Schema};
use datafusion::common::scalar::ScalarStructBuilder;
use datafusion::common::DFSchema;
use datafusion::common::{assert_contains, DFSchema};
use datafusion::datasource::empty::EmptyTable;
use datafusion::datasource::{DefaultTableSource, TableProvider};
use datafusion::execution::{SessionState, SessionStateBuilder};
use datafusion::logical_expr::registry::SerializerRegistry;
use datafusion::logical_expr::TableSource;
use datafusion::prelude::SessionContext;
use std::sync::OnceLock;

Expand Down Expand Up @@ -2585,4 +2603,110 @@ mod test {

assert!(matches!(err, Err(DataFusionError::SchemaError(_, _))));
}

// Round-trips a logical plan that scans a locally registered table through Substrait,
// between two independent session contexts. The round trip is expected to fail when no
// custom `SerializerRegistry` is configured, and to succeed (with the table encoded as an
// `ExtensionTable`) when both contexts use the `Registry` defined below.
#[tokio::test]
async fn round_trip_extension_table() {
const TABLE_NAME: &str = "custom_table";
// Opaque payload standing in for a real serialized table definition.
const SERIALIZED: &[u8] = "table definition".as_bytes();

// Builds the table under test: an `EmptyTable` with the columns `id` and `name`.
fn custom_table() -> Arc<dyn TableProvider> {
Arc::new(EmptyTable::new(Arc::new(Schema::new([
Arc::new(Field::new("id", DataType::Int32, false)),
Arc::new(Field::new("name", DataType::Utf8, false)),
]))))
}

// Registry that only knows how to (de)serialize `custom_table`.
#[derive(Debug)]
struct Registry;
impl SerializerRegistry for Registry {
// The table is recognized by comparing its schema with the schema of `custom_table()`;
// any other table yields an error.
fn serialize_custom_table(&self, table: &dyn TableSource) -> Result<Vec<u8>> {
if table.schema() == custom_table().schema() {
Ok(SERIALIZED.to_vec())
} else {
Err(DataFusionError::Internal("Not our table".into()))
}
}
// Accepts only the expected name/payload pair and panics on anything else.
fn deserialize_custom_table(
&self,
name: &str,
bytes: &[u8],
) -> Result<Arc<dyn TableSource>> {
if name == TABLE_NAME && bytes == SERIALIZED {
Ok(Arc::new(DefaultTableSource::new(custom_table())))
} else {
panic!("Unexpected extension table: {name}");
}
}
}

// Produces the plan in `local`, converts it to Substrait, restores it in `remote`,
// executes the restored plan there, and compares the textual plans.
async fn round_trip_logical_plans(
local: &SessionContext,
remote: &SessionContext,
) -> Result<()> {
// Only the local context gets the table registered.
local.register_table(TABLE_NAME, custom_table())?;
remote.table_provider(TABLE_NAME).await.expect_err(
"The remote context is not supposed to know about custom_table",
);
let initial_plan = local
.sql(&format!("select id from {TABLE_NAME}"))
.await?
.logical_plan()
.clone();

// write substrait locally
let substrait = to_substrait_plan(&initial_plan, &local.state())?;

// read substrait remotely
// since we know there's no `custom_table` registered in the remote context, this will only succeed
// if our table got encoded as an ExtensionTable and is now decoded back to a table source.
let restored = from_substrait_plan(&remote.state(), &substrait).await?;
assert_contains!(
// confirm that the Substrait plan contains our custom_table as an ExtensionTable
serde_json::to_string(substrait.as_ref()).unwrap(),
format!(r#""extensionTable":{{"detail":{{"typeUrl":"{TABLE_NAME}","#)
);
remote // make sure the restored plan is fully working in the remote context
.execute_logical_plan(restored.clone())
.await?
.collect()
.await
.expect("Restored plan cannot be executed remotely");
assert_eq!(
// check that the restored plan is functionally equivalent (and almost identical) to the initial one
initial_plan.to_string(),
restored.to_string().replace(
// substrait will add an explicit full-schema projection if the original table had none
&format!("TableScan: {TABLE_NAME} projection=[id, name]"),
&format!("TableScan: {TABLE_NAME}"),
)
);
Ok(())
}

// take 1
// Contexts created with `SessionContext::new()` do not use `Registry`, so the round
// trip must fail with a "No table named ..." error.
let failed_attempt =
round_trip_logical_plans(&SessionContext::new(), &SessionContext::new())
.await
.expect_err(
"The round trip should fail in the absence of a SerializerRegistry",
);
assert_contains!(
failed_attempt.message(),
format!("No table named '{TABLE_NAME}'")
);

// take 2
// Both contexts are built with `Registry`, so the round trip must succeed.
fn proper_context() -> SessionContext {
SessionContext::new_with_state(
SessionStateBuilder::new()
// This will transport our custom_table as a Substrait ExtensionTable
.with_serializer_registry(Arc::new(Registry))
.build(),
)
}

round_trip_logical_plans(&proper_context(), &proper_context())
.await
.expect("Local plan could not be restored remotely");
}
}
Loading