Skip to content

Commit

Permalink
[substrait] Add support for ExtensionTable
Browse files Browse the repository at this point in the history
  • Loading branch information
ccciudatu committed Dec 13, 2024
1 parent 8b6daaf commit 081e886
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 25 deletions.
51 changes: 43 additions & 8 deletions datafusion/expr/src/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
use crate::expr_rewriter::FunctionRewrite;
use crate::planner::ExprPlanner;
use crate::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF};
use datafusion_common::{not_impl_err, plan_datafusion_err, HashMap, Result};
use crate::{AggregateUDF, ScalarUDF, TableSource, UserDefinedLogicalNode, WindowUDF};
use datafusion_common::{not_impl_err, plan_datafusion_err, DataFusionError, HashMap, Result};
use std::collections::HashSet;
use std::fmt::Debug;
use std::sync::Arc;
Expand Down Expand Up @@ -123,22 +123,57 @@ pub trait FunctionRegistry {
}
}

/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode].
/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode]
/// and custom table providers for which the name alone is meaningless in the target
/// execution context, e.g. UDTFs, manually registered tables etc.
pub trait SerializerRegistry: Debug + Send + Sync {
/// Serialize this node to a byte array. This serialization should not include
/// input plans.
///
/// # Errors
///
/// The default implementation always returns a [`DataFusionError::Plan`] error
/// stating that `UserDefinedLogicalNode` serialization is not supported.
fn serialize_logical_plan(
&self,
node: &dyn UserDefinedLogicalNode,
) -> Result<Vec<u8>>;
_node: &dyn UserDefinedLogicalNode,
) -> Result<Vec<u8>> {
Err(DataFusionError::Plan("UserDefinedLogicalNode serialization not supported".into()))
}

/// Deserialize user defined logical plan node ([UserDefinedLogicalNode]) from
/// bytes.
///
/// # Errors
///
/// The default implementation always returns a [`DataFusionError::Plan`] error
/// stating that `UserDefinedLogicalNode` deserialization is not supported.
fn deserialize_logical_plan(
&self,
name: &str,
bytes: &[u8],
) -> Result<Arc<dyn UserDefinedLogicalNode>>;
_name: &str,
_bytes: &[u8],
) -> Result<Arc<dyn UserDefinedLogicalNode>> {
Err(DataFusionError::Plan("UserDefinedLogicalNode deserialization not supported".into()))
}

/// Produces the binary payload of a custom table so that it can be shipped as a
/// substrait extension table.
///
/// Implementations should report success only for tables that cannot be looked
/// up by name in the destination execution context, such as UDTFs or manually
/// registered table providers.
///
/// # Errors
///
/// The default implementation always fails with a [`DataFusionError::Plan`].
fn serialize_custom_table(&self, _table: &dyn TableSource) -> Result<Vec<u8>> {
    let message = String::from("Custom table serialization not supported");
    Err(DataFusionError::Plan(message))
}

/// Rebuilds the custom table with the given name from its binary payload.
///
/// The name is not a reliable discriminator when several UDTF/TableProvider
/// implementations are expected: UDTFs in DataFusion are always registered
/// under the same name (`tmp_table`), so implementors should inspect the binary
/// payload to tell table types apart.
///
/// A possible future improvement is to have
/// [SerializerRegistry::serialize_custom_table] return a (name, bytes) tuple,
/// letting implementors assign a distinct name per table provider
/// implementation (e.g. the protobuf `type_url` when using proto). The downside
/// is that table names in the restored plan might then differ from the
/// original ones.
///
/// # Errors
///
/// The default implementation always fails with a [`DataFusionError::Plan`].
fn deserialize_custom_table(
    &self,
    _name: &str,
    _bytes: &[u8],
) -> Result<Arc<dyn TableSource>> {
    let message = String::from("Custom table deserialization not supported");
    Err(DataFusionError::Plan(message))
}
}

/// A [`FunctionRegistry`] that uses in memory [`HashMap`]s
Expand Down
28 changes: 22 additions & 6 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,7 @@ use datafusion::common::{
use datafusion::datasource::provider_as_source;
use datafusion::logical_expr::expr::{Exists, InSubquery, Sort};

use datafusion::logical_expr::{
Aggregate, BinaryExpr, Case, EmptyRelation, Expr, ExprSchemable, LogicalPlan,
Operator, Projection, SortExpr, TryCast, Values,
};
use datafusion::logical_expr::{Aggregate, BinaryExpr, Case, EmptyRelation, Expr, ExprSchemable, LogicalPlan, Operator, Projection, SortExpr, TableScan, TryCast, Values};
use substrait::proto::aggregate_rel::Grouping;
use substrait::proto::expression::subquery::set_predicate::PredicateOp;
use substrait::proto::expression_reference::ExprType;
Expand Down Expand Up @@ -994,8 +991,27 @@ pub async fn from_substrait_rel(
)
.await
}
_ => {
not_impl_err!("Unsupported ReadType: {:?}", &read.as_ref().read_type)
Some(ReadType::ExtensionTable(ext)) => {
if let Some(ext_detail) = &ext.detail {
let source = state.serializer_registry()
.deserialize_custom_table(&ext_detail.type_url, &ext_detail.value)?;
let table_name = if let Some((_, name)) = ext_detail.type_url.rsplit_once('/') {
name
} else {
&ext_detail.type_url
};
let plan = LogicalPlan::TableScan(
TableScan::try_new(table_name, source, None, vec![], None)?
);
let schema = apply_masking(substrait_schema, &read.projection)?;
ensure_schema_compatability(plan.schema(), schema.clone())?;
apply_projection(plan, schema)
} else {
substrait_err!("Unexpected empty detail in ExtensionTable")
}
},
None => {
substrait_err!("Unexpected empty read_type")
}
}
}
Expand Down
123 changes: 112 additions & 11 deletions datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ use substrait::proto::expression::literal::{
};
use substrait::proto::expression::subquery::InPredicate;
use substrait::proto::expression::window_function::BoundsType;
use substrait::proto::read_rel::VirtualTable;
use substrait::proto::read_rel::{ExtensionTable, VirtualTable};
use substrait::proto::rel_common::EmitKind;
use substrait::proto::rel_common::EmitKind::Emit;
use substrait::proto::{
Expand Down Expand Up @@ -211,6 +211,22 @@ pub fn to_substrait_rel(
let table_schema = scan.source.schema().to_dfschema_ref()?;
let base_schema = to_substrait_named_struct(&table_schema)?;

let table = if let Ok(bytes) = state
.serializer_registry()
.serialize_custom_table(scan.source.as_ref()) {
ReadType::ExtensionTable(ExtensionTable {
detail: Some(ProtoAny {
type_url: scan.table_name.to_string(),
value: bytes.into(),
})
})
} else {
ReadType::NamedTable(NamedTable {
names: scan.table_name.to_vec(),
advanced_extension: None,
})
};

Ok(Box::new(Rel {
rel_type: Some(RelType::Read(Box::new(ReadRel {
common: None,
Expand All @@ -219,10 +235,7 @@ pub fn to_substrait_rel(
best_effort_filter: None,
projection,
advanced_extension: None,
read_type: Some(ReadType::NamedTable(NamedTable {
names: scan.table_name.to_vec(),
advanced_extension: None,
})),
read_type: Some(table),
}))),
}))
}
Expand Down Expand Up @@ -2200,18 +2213,20 @@ fn substrait_field_ref(index: usize) -> Result<Expression> {
#[cfg(test)]
mod test {
use super::*;
use crate::logical_plan::consumer::{
from_substrait_extended_expr, from_substrait_literal_without_names,
from_substrait_named_struct, from_substrait_type_without_names,
};
use crate::logical_plan::consumer::{from_substrait_extended_expr, from_substrait_literal_without_names, from_substrait_named_struct, from_substrait_plan, from_substrait_type_without_names};
use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
use datafusion::arrow::array::{
GenericListArray, Int64Builder, MapBuilder, StringBuilder,
};
use datafusion::arrow::datatypes::{Field, Fields, Schema};
use datafusion::common::scalar::ScalarStructBuilder;
use datafusion::common::DFSchema;
use datafusion::execution::SessionStateBuilder;
use datafusion::common::{assert_contains, DFSchema};
use datafusion::datasource::{DefaultTableSource, TableProvider};
use datafusion::datasource::empty::EmptyTable;
use datafusion::execution::registry::SerializerRegistry;
use datafusion::execution::{SessionState, SessionStateBuilder};
use datafusion::logical_expr::TableSource;
use datafusion::prelude::SessionContext;

#[test]
fn round_trip_literals() -> Result<()> {
Expand Down Expand Up @@ -2518,4 +2533,90 @@ mod test {

assert!(matches!(err, Err(DataFusionError::SchemaError(_, _))));
}

#[tokio::test]
async fn round_trip_extension_table() -> Result<()> {
#[derive(Debug)]
struct Registry {
table: Arc<dyn TableProvider>,
}
impl SerializerRegistry for Registry {
fn serialize_custom_table(&self, _table: &dyn TableSource) -> Result<Vec<u8>> {
Ok("expected payload".as_bytes().to_vec())
}
fn deserialize_custom_table(&self, _name: &str, _bytes: &[u8]) -> Result<Arc<dyn TableSource>> {
Ok(Arc::new(DefaultTableSource::new(self.table.clone())))
}
}

async fn round_trip_logical_plans(
local: &SessionContext,
remote: &SessionContext,
table: Arc<dyn TableProvider>
) -> Result<(LogicalPlan, LogicalPlan)> {
local.register_table("custom_table", table)?;
let initial_plan = local.sql("select id from custom_table")
.await?
.logical_plan()
.clone();

// write substrait locally
let substrait = to_substrait_plan(&initial_plan, &local.state())?;

// read substrait remotely
// this will only succeed if our custom_table was encoded as an ExtensionTable,
// since there's no `custom_table` registered in the remote context.
let restored = from_substrait_plan(&remote.state(), &substrait).await?;
assert_contains!(
serde_json::to_string(substrait.as_ref()).unwrap(),
// value == base64("expected payload")
r#""extensionTable":{"detail":{"typeUrl":"custom_table","value":"ZXhwZWN0ZWQgcGF5bG9hZA=="}}"#
);
Ok((initial_plan, restored))
}

let empty = Arc::new(EmptyTable::new(Arc::new(
Schema::new([
Arc::new(Field::new("id", DataType::Int32, false)),
Arc::new(Field::new("name", DataType::Utf8, false)),
])
)));

let first_attempt = round_trip_logical_plans(
&SessionContext::new(),
&SessionContext::new(),
empty.clone()
).await;
assert_eq!(
first_attempt.unwrap_err().to_string(),
"Error during planning: No table named 'custom_table'"
);
fn proper_state(table: Arc<dyn TableProvider>) -> SessionState {
SessionStateBuilder::new()
.with_default_features()
.with_serializer_registry(Arc::new(Registry { table }))
.build()
}
let local = SessionContext::new_with_state(proper_state(empty.clone()));
let remote = SessionContext::new_with_state(proper_state(empty.clone()));

let (initial_plan, restored) = round_trip_logical_plans(&local, &remote, empty)
.await
.expect("Should restore the substrait plan as datafusion logical plan");

assert_eq!(
initial_plan.to_string(),
restored.to_string()
// substrait will add an explicit projection with the full schema
.replace(
"TableScan: custom_table projection=[id, name]",
"TableScan: custom_table"
)
);
assert_eq!(
local.execute_logical_plan(initial_plan).await?.collect().await?,
remote.execute_logical_plan(restored).await?.collect().await?,
);
Ok(())
}
}

0 comments on commit 081e886

Please sign in to comment.