Skip to content

Commit

Permalink
cargo fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
ccciudatu committed Dec 13, 2024
1 parent 081e886 commit a65e926
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 37 deletions.
27 changes: 17 additions & 10 deletions datafusion/expr/src/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
use crate::expr_rewriter::FunctionRewrite;
use crate::planner::ExprPlanner;
use crate::{AggregateUDF, ScalarUDF, TableSource, UserDefinedLogicalNode, WindowUDF};
use datafusion_common::{not_impl_err, plan_datafusion_err, DataFusionError, HashMap, Result};
use datafusion_common::{
not_impl_err, plan_datafusion_err, DataFusionError, HashMap, Result,
};
use std::collections::HashSet;
use std::fmt::Debug;
use std::sync::Arc;
Expand Down Expand Up @@ -133,7 +135,9 @@ pub trait SerializerRegistry: Debug + Send + Sync {
&self,
_node: &dyn UserDefinedLogicalNode,
) -> Result<Vec<u8>> {
Err(DataFusionError::Plan("UserDefinedLogicalNode serialization not supported".into()))
Err(DataFusionError::Plan(
"UserDefinedLogicalNode serialization not supported".into(),
))
}

/// Deserialize user defined logical plan node ([UserDefinedLogicalNode]) from
Expand All @@ -143,24 +147,25 @@ pub trait SerializerRegistry: Debug + Send + Sync {
_name: &str,
_bytes: &[u8],
) -> Result<Arc<dyn UserDefinedLogicalNode>> {
Err(DataFusionError::Plan("UserDefinedLogicalNode deserialization not supported".into()))
Err(DataFusionError::Plan(
"UserDefinedLogicalNode deserialization not supported".into(),
))
}

/// Binary representation for custom tables, to be converted to substrait extension tables.
/// Should only return success for table implementations that cannot be found by name
/// in the destination execution context, such as UDTFs or manually registered table providers.
fn serialize_custom_table(
&self,
_table: &dyn TableSource,
) -> Result<Vec<u8>> {
Err(DataFusionError::Plan("Custom table serialization not supported".into()))
fn serialize_custom_table(&self, _table: &dyn TableSource) -> Result<Vec<u8>> {
Err(DataFusionError::Plan(
"Custom table serialization not supported".into(),
))
}

/// Deserialize the custom table with the given name.
/// The name may not be useful as a discriminator if multiple UDTF/TableProvider
/// implementations are expected. This is particularly true for UDTFs in DataFusion,
/// which are always registered under the same name: `tmp_table`, so one should
/// use the binary payload to distinguish between multiple table types.
/// use the binary payload to distinguish between multiple potential table types.
/// A potential future improvement would be to return a (name, bytes) tuple from
/// [SerializerRegistry::serialize_custom_table] to allow the implementors to assign
/// different names to different table provider implementations (e.g. in the case of proto,
Expand All @@ -172,7 +177,9 @@ pub trait SerializerRegistry: Debug + Send + Sync {
_name: &str,
_bytes: &[u8],
) -> Result<Arc<dyn TableSource>> {
Err(DataFusionError::Plan("Custom table deserialization not supported".into()))
Err(DataFusionError::Plan(
"Custom table deserialization not supported".into(),
))
}
}

Expand Down
28 changes: 20 additions & 8 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ use datafusion::common::{
use datafusion::datasource::provider_as_source;
use datafusion::logical_expr::expr::{Exists, InSubquery, Sort};

use datafusion::logical_expr::{Aggregate, BinaryExpr, Case, EmptyRelation, Expr, ExprSchemable, LogicalPlan, Operator, Projection, SortExpr, TableScan, TryCast, Values};
use datafusion::logical_expr::{
Aggregate, BinaryExpr, Case, EmptyRelation, Expr, ExprSchemable, LogicalPlan,
Operator, Projection, SortExpr, TableScan, TryCast, Values,
};
use substrait::proto::aggregate_rel::Grouping;
use substrait::proto::expression::subquery::set_predicate::PredicateOp;
use substrait::proto::expression_reference::ExprType;
Expand Down Expand Up @@ -993,23 +996,32 @@ pub async fn from_substrait_rel(
}
Some(ReadType::ExtensionTable(ext)) => {
if let Some(ext_detail) = &ext.detail {
let source = state.serializer_registry()
.deserialize_custom_table(&ext_detail.type_url, &ext_detail.value)?;
let table_name = if let Some((_, name)) = ext_detail.type_url.rsplit_once('/') {
let source =
state.serializer_registry().deserialize_custom_table(
&ext_detail.type_url,
&ext_detail.value,
)?;
let table_name = if let Some((_, name)) =
ext_detail.type_url.rsplit_once('/')
{
name
} else {
&ext_detail.type_url
};
let plan = LogicalPlan::TableScan(
TableScan::try_new(table_name, source, None, vec![], None)?
);
let plan = LogicalPlan::TableScan(TableScan::try_new(
table_name,
source,
None,
vec![],
None,
)?);
let schema = apply_masking(substrait_schema, &read.projection)?;
ensure_schema_compatability(plan.schema(), schema.clone())?;
apply_projection(plan, schema)
} else {
substrait_err!("Unexpected empty detail in ExtensionTable")
}
},
}
None => {
substrait_err!("Unexpected empty read_type")
}
Expand Down
59 changes: 40 additions & 19 deletions datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,12 +213,13 @@ pub fn to_substrait_rel(

let table = if let Ok(bytes) = state
.serializer_registry()
.serialize_custom_table(scan.source.as_ref()) {
.serialize_custom_table(scan.source.as_ref())
{
ReadType::ExtensionTable(ExtensionTable {
detail: Some(ProtoAny {
type_url: scan.table_name.to_string(),
value: bytes.into(),
})
}),
})
} else {
ReadType::NamedTable(NamedTable {
Expand Down Expand Up @@ -2213,16 +2214,20 @@ fn substrait_field_ref(index: usize) -> Result<Expression> {
#[cfg(test)]
mod test {
use super::*;
use crate::logical_plan::consumer::{from_substrait_extended_expr, from_substrait_literal_without_names, from_substrait_named_struct, from_substrait_plan, from_substrait_type_without_names};
use crate::logical_plan::consumer::{
from_substrait_extended_expr, from_substrait_literal_without_names,
from_substrait_named_struct, from_substrait_plan,
from_substrait_type_without_names,
};
use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
use datafusion::arrow::array::{
GenericListArray, Int64Builder, MapBuilder, StringBuilder,
};
use datafusion::arrow::datatypes::{Field, Fields, Schema};
use datafusion::common::scalar::ScalarStructBuilder;
use datafusion::common::{assert_contains, DFSchema};
use datafusion::datasource::{DefaultTableSource, TableProvider};
use datafusion::datasource::empty::EmptyTable;
use datafusion::datasource::{DefaultTableSource, TableProvider};
use datafusion::execution::registry::SerializerRegistry;
use datafusion::execution::{SessionState, SessionStateBuilder};
use datafusion::logical_expr::TableSource;
Expand Down Expand Up @@ -2541,21 +2546,29 @@ mod test {
table: Arc<dyn TableProvider>,
}
impl SerializerRegistry for Registry {
fn serialize_custom_table(&self, _table: &dyn TableSource) -> Result<Vec<u8>> {
fn serialize_custom_table(
&self,
_table: &dyn TableSource,
) -> Result<Vec<u8>> {
Ok("expected payload".as_bytes().to_vec())
}
fn deserialize_custom_table(&self, _name: &str, _bytes: &[u8]) -> Result<Arc<dyn TableSource>> {
fn deserialize_custom_table(
&self,
_name: &str,
_bytes: &[u8],
) -> Result<Arc<dyn TableSource>> {
Ok(Arc::new(DefaultTableSource::new(self.table.clone())))
}
}

async fn round_trip_logical_plans(
local: &SessionContext,
remote: &SessionContext,
table: Arc<dyn TableProvider>
table: Arc<dyn TableProvider>,
) -> Result<(LogicalPlan, LogicalPlan)> {
local.register_table("custom_table", table)?;
let initial_plan = local.sql("select id from custom_table")
let initial_plan = local
.sql("select id from custom_table")
.await?
.logical_plan()
.clone();
Expand All @@ -2575,18 +2588,17 @@ mod test {
Ok((initial_plan, restored))
}

let empty = Arc::new(EmptyTable::new(Arc::new(
Schema::new([
Arc::new(Field::new("id", DataType::Int32, false)),
Arc::new(Field::new("name", DataType::Utf8, false)),
])
)));
let empty = Arc::new(EmptyTable::new(Arc::new(Schema::new([
Arc::new(Field::new("id", DataType::Int32, false)),
Arc::new(Field::new("name", DataType::Utf8, false)),
]))));

let first_attempt = round_trip_logical_plans(
&SessionContext::new(),
&SessionContext::new(),
empty.clone()
).await;
empty.clone(),
)
.await;
assert_eq!(
first_attempt.unwrap_err().to_string(),
"Error during planning: No table named 'custom_table'"
Expand All @@ -2606,16 +2618,25 @@ mod test {

assert_eq!(
initial_plan.to_string(),
restored.to_string()
restored
.to_string()
// substrait will add an explicit projection with the full schema
.replace(
"TableScan: custom_table projection=[id, name]",
"TableScan: custom_table"
)
);
assert_eq!(
local.execute_logical_plan(initial_plan).await?.collect().await?,
remote.execute_logical_plan(restored).await?.collect().await?,
local
.execute_logical_plan(initial_plan)
.await?
.collect()
.await?,
remote
.execute_logical_plan(restored)
.await?
.collect()
.await?,
);
Ok(())
}
Expand Down

0 comments on commit a65e926

Please sign in to comment.