This repository has been archived by the owner on Jun 6, 2024. It is now read-only.

refactor dag creation
connortsui20 committed Feb 25, 2024
1 parent 9a6aa87 commit 2256190
Showing 5 changed files with 159 additions and 122 deletions.
226 changes: 124 additions & 102 deletions eggstrain/src/execution/query_dag.rs
@@ -1,97 +1,120 @@
use super::operators::filter::Filter;
use super::operators::project::Project;
use super::operators::UnaryOperator;
use super::operators::{BinaryOperator, UnaryOperator};
use crate::BATCH_SIZE;
use arrow::record_batch::RecordBatch;
use datafusion::physical_plan::aggregates::AggregateExec;
use datafusion::physical_plan::filter::FilterExec;
use datafusion::physical_plan::{projection::ProjectionExec, ExecutionPlan, Partitioning};
use datafusion::physical_plan::joins::HashJoinExec;
use datafusion::physical_plan::projection::ProjectionExec;
use datafusion::physical_plan::sorts::sort::SortExec;
use datafusion::physical_plan::{ExecutionPlan, Partitioning};
use datafusion_common::{DataFusionError, Result};
use futures::stream::StreamExt;
use std::any::TypeId;
use std::collections::VecDeque;
use std::sync::Arc;
use tokio::sync::broadcast;

const BATCH_SIZE: usize = 1024;

#[non_exhaustive]
/// Represents the operators supported by the `eggstrain` execution engine.
///
/// TODO docs
#[derive(Clone)]
pub(crate) enum EggstrainOperator {
Project(Arc<dyn UnaryOperator<In = RecordBatch, Out = RecordBatch>>),
Filter(Arc<dyn UnaryOperator<In = RecordBatch, Out = RecordBatch>>),
// Sort(Arc<dyn UnaryNode>),

// Aggregate(Arc<dyn UnaryNode>),

// TableScan(Arc<dyn UnaryNode>),

// HashJoin(Arc<dyn BinaryNode>),
// TODO remove `dead_code` once implemented
#[allow(dead_code)]
Sort(Arc<dyn UnaryOperator<In = RecordBatch, Out = RecordBatch>>),
#[allow(dead_code)]
Aggregate(Arc<dyn UnaryOperator<In = RecordBatch, Out = RecordBatch>>),
#[allow(dead_code)]
TableScan(Arc<dyn UnaryOperator<In = RecordBatch, Out = RecordBatch>>),
#[allow(dead_code)]
HashJoin(
Arc<dyn BinaryOperator<InLeft = RecordBatch, InRight = RecordBatch, Out = RecordBatch>>,
),
}

impl EggstrainOperator {
/// Extracts the inner value and calls the `children` method.
fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
match self {
Self::Project(x) => x.children(),
Self::Filter(x) => x.children(),
_ => unimplemented!(),
}
}
}

fn extract_df_node(plan: Arc<dyn ExecutionPlan>) -> Result<EggstrainOperator> {
// Cast the plan as an Any type
// Match against the type id
// If it matches a specific one, try to downcast_ref
// else return an error

/// Takes as input a DataFusion `ExecutionPlan` (physical plan), and attempts to parse the root node
/// into an `EggstrainOperator`.
///
/// In order to do this, we make use of the `Any` type to downcast the trait object into a real
/// Rust type.
///
/// TODO docs
fn parse_execution_plan_root(plan: &Arc<dyn ExecutionPlan>) -> Result<EggstrainOperator> {
let root = plan.as_any();
let id = root.type_id();

if id == TypeId::of::<ProjectionExec>() {
let projection_plan = root
.downcast_ref::<ProjectionExec>()
.expect("Unable to downcast_ref to ProjectionExec");
let Some(projection_plan) = root.downcast_ref::<ProjectionExec>() else {
return Err(DataFusionError::NotImplemented(
"Unable to downcast DataFusion ExecutionPlan to ProjectionExec".to_string(),
));
};

let child_schema = projection_plan.children()[0].schema();
let node = Project::new(child_schema, projection_plan);

let node = Project::new(child_schema, projection_plan).into_unary();

Ok(EggstrainOperator::Project(node))
Ok(EggstrainOperator::Project(node.into_unary()))
} else if id == TypeId::of::<FilterExec>() {
let filter_plan = root
.downcast_ref::<FilterExec>()
.expect("Unable to downcast_ref to FilterExec");

let node =
Filter::new(filter_plan.predicate().clone(), filter_plan.children()).into_unary();

Ok(EggstrainOperator::Filter(node))

// } else if id == TypeId::of::<HashJoinExec>() {
// todo!();
// } else if id == TypeId::of::<SortExec>() {
// todo!();
// } else if id == TypeId::of::<AggregateExec>() {
// todo!();
let Some(filter_plan) = root.downcast_ref::<FilterExec>() else {
return Err(DataFusionError::NotImplemented(
"Unable to downcast DataFusion ExecutionPlan to ProjectionExec".to_string(),
));
};

let predicate = filter_plan.predicate().clone();
let children = filter_plan.children();
let node = Filter::new(predicate, children);

Ok(EggstrainOperator::Filter(node.into_unary()))
} else if id == TypeId::of::<HashJoinExec>() {
unimplemented!("HashJoin not implemented");
} else if id == TypeId::of::<SortExec>() {
unimplemented!("Sort not implemented");
} else if id == TypeId::of::<AggregateExec>() {
unimplemented!("Aggregate not implemented");
} else {
Err(DataFusionError::NotImplemented(
"Other operators not implemented".to_string(),
))
}
}
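
As a side note, the `as_any`/`TypeId`/`downcast_ref` dispatch used above is the standard library's usual pattern for recovering a concrete type from a trait object. A minimal, self-contained sketch of that pattern (the `Node` trait and `Projection` struct here are illustrative stand-ins, not types from eggstrain or DataFusion):

use std::any::{Any, TypeId};

// Illustrative stand-in for a plan node trait object.
trait Node {
    // Expose the concrete type behind the trait object, like DataFusion's `as_any`.
    fn as_any(&self) -> &dyn Any;
}

struct Projection;

impl Node for Projection {
    fn as_any(&self) -> &dyn Any {
        self
    }
}

fn identify(node: &dyn Node) -> &'static str {
    let any = node.as_any();
    if any.type_id() == TypeId::of::<Projection>() {
        // The TypeId check already passed, so this downcast is expected to succeed.
        let _proj = any.downcast_ref::<Projection>().expect("checked TypeId above");
        "projection"
    } else {
        "unknown"
    }
}

fn main() {
    assert_eq!(identify(&Projection), "projection");
}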

fn df_execute_node(plan: Arc<dyn ExecutionPlan>, tx: broadcast::Sender<RecordBatch>) {
let partitioning = plan.output_partitioning();
let partitions = match partitioning {
/// Wrapper around DataFusion's `ExecutionPlan::execute` to integrate it with `eggstrain`'s
/// execution architecture.
///
/// TODO docs
fn datafusion_execute(plan: Arc<dyn ExecutionPlan>, tx: broadcast::Sender<RecordBatch>) {
// DataFusion execution nodes will output multiple streams that are partitioned by the following
// patterns, so just join them all into one stream
let partitions = match plan.output_partitioning() {
Partitioning::RoundRobinBatch(c) => c,
Partitioning::Hash(_, h) => h,
Partitioning::UnknownPartitioning(p) => p,
};

// In a separate tokio task, send batches to the next operator over the `tx` channel, making
// sure to drain all of the partitions
tokio::spawn(async move {
for i in 0..partitions {
let batch_stream = plan.execute(i, Default::default()).unwrap();

let results = batch_stream.collect::<Vec<_>>().await;

for batch in results {
let batch = batch.unwrap();
if batch.num_rows() == 0 {
@@ -105,74 +128,73 @@ fn df_execute_node(plan: Arc<dyn ExecutionPlan>, tx: broadcast::Sender<RecordBatch>) {
});
}
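
For readers unfamiliar with the plumbing used here, each DAG edge is a plain `tokio::sync::broadcast` channel: a spawned task pushes batches into the sender and the downstream operator drains its receiver until the channel closes. A rough, self-contained sketch with `u32` standing in for `RecordBatch` (the capacity of 1024 mirrors `BATCH_SIZE`):

use tokio::sync::broadcast;

#[tokio::main]
async fn main() {
    // One broadcast channel per DAG edge.
    let (tx, mut rx) = broadcast::channel::<u32>(1024);

    // Producer task: stands in for a node streaming record batches downstream.
    tokio::spawn(async move {
        for i in 0..3 {
            // `send` only fails once every receiver has been dropped.
            tx.send(i).expect("receiver still alive");
        }
        // Dropping `tx` here closes the channel and ends the consumer loop below.
    });

    // Consumer side: stands in for the downstream operator.
    while let Ok(batch) = rx.recv().await {
        println!("got batch {batch}");
    }
}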

pub fn build_query_dag(plan: Arc<dyn ExecutionPlan>) -> Result<broadcast::Receiver<RecordBatch>> {
// A tuple containing a plan node and a sender into that node
/// Sets up a unary operator in the execution DAG.
///
/// TODO docs
fn setup_unary_operator(
queue: &mut VecDeque<(EggstrainOperator, broadcast::Sender<RecordBatch>)>,
node: EggstrainOperator,
tx: broadcast::Sender<RecordBatch>,
) {
// Defines the edge between operators in the DAG
let (child_tx, child_rx) = broadcast::channel::<RecordBatch>(BATCH_SIZE);

// Create the operator's tokio task
match node.clone() {
EggstrainOperator::Project(eggnode) | EggstrainOperator::Filter(eggnode) => {
let tx = tx.clone();
tokio::spawn(async move {
eggnode.execute(child_rx, tx).await;
});
}
_ => unimplemented!(),
};

let child_plan = node.children()[0].clone();

// If we do not know how to deal with a DataFusion execution plan node, then default to just
// executing DataFusion for that node
match parse_execution_plan_root(&child_plan) {
Ok(val) => {
queue.push_back((val, child_tx));
}
Err(_) => {
datafusion_execute(child_plan, child_tx);
}
}
}

/// Builds the execution DAG.
///
/// Note: Not actually a DAG right now, just a tree.
///
/// TODO docs
pub fn build_execution_dag(
plan: Arc<dyn ExecutionPlan>,
) -> Result<broadcast::Receiver<RecordBatch>> {
// If we do not recognize the root node, then we might as well just use DataFusion completely
let root: EggstrainOperator = parse_execution_plan_root(&plan).map_err(|_| {
DataFusionError::NotImplemented(
"The root node of the input physical plan was not recognized".to_string(),
)
})?;

// A queue of tuples containing an execution operator and the `broadcast::Sender` channel side
// to send `RecordBatch`es into that operator
let mut queue = VecDeque::new();

// Final output is going to be sent to root_rx
// Create the topmost channel; the final outputs will be sent to `root_rx`
let (root_tx, root_rx) = broadcast::channel::<RecordBatch>(BATCH_SIZE);
// Children of the root will use root_tx to send to the root

let root = extract_df_node(plan)?;

// Run BFS on the `ExecutionPlan` and create our own execution DAG
queue.push_back((root, root_tx));

while let Some((node, tx)) = queue.pop_front() {
let node = node.clone();

match node.children().len() {
0 => {
todo!();
}
1 => {
let (child_tx, child_rx) = broadcast::channel::<RecordBatch>(BATCH_SIZE);
let child_plan = node.children()[0].clone();

match node.clone() {
EggstrainOperator::Project(eggnode) | EggstrainOperator::Filter(eggnode) => {
let tx = tx.clone();
tokio::spawn(async move {
eggnode.execute(child_rx, tx).await;
});
}
};

match extract_df_node(child_plan.clone()) {
Ok(val) => {
queue.push_back((val, child_tx));
}
Err(_) => {
df_execute_node(child_plan, child_tx);
}
}
}
2 => {
todo!();
}
_ => {
return Err(DataFusionError::NotImplemented(
"More than 2 children not implemented".to_string(),
));
}
0 => unimplemented!(),
1 => setup_unary_operator(&mut queue, node, tx),
2 => unimplemented!(),
n => unreachable!("Nodes should not have more than 2 children, saw {}", n),
}

// for child in node.children() {
// let (child_tx, child_rx) = broadcast::channel::<RecordBatch>(BATCH_SIZE);

// if let Ok(child_node) = extract_df_node(child.clone()) {
// match child_node.clone() {
// EggstrainOperator::Project(eggnode) | EggstrainOperator::Filter(eggnode) => {
// let tx = tx.clone();
// tokio::spawn(async move {
// eggnode.execute(child_rx, tx).await;
// });
// }
// };
// queue.push_back((child_node, child_tx));
// } else {
// df_execute_node(child.clone(), tx.clone());
// }
// }
}

Ok(root_rx)
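
Stripped of the operator and channel details, the driving pattern of `build_execution_dag` above is an ordinary breadth-first traversal over a `VecDeque`. A toy sketch, using `u32` node ids in place of real operators and omitting the channels entirely:

use std::collections::VecDeque;

fn main() {
    // Toy tree standing in for the physical plan: node id -> child ids.
    let children = |node: u32| -> Vec<u32> {
        match node {
            0 => vec![1, 2],
            1 => vec![3],
            _ => vec![],
        }
    };

    // Pop a node, set it up, then enqueue its children.
    let mut queue = VecDeque::new();
    queue.push_back(0u32);
    while let Some(node) = queue.pop_front() {
        println!("setting up operator for node {node}");
        for child in children(node) {
            queue.push_back(child);
        }
    }
}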
33 changes: 23 additions & 10 deletions eggstrain/src/lib.rs
@@ -2,16 +2,31 @@ use arrow::record_batch::RecordBatch;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::*;
use datafusion_common::Result;
use std::io;
use execution::query_dag::build_execution_dag;
use std::sync::Arc;

pub mod execution;
pub mod scheduler_client;
pub mod storage_client;

use execution::query_dag::build_query_dag;
const BATCH_SIZE: usize = 1024;

pub async fn tpch_dataframe() -> Result<DataFrame> {
/// Creates a `SessionContext` that contains the base tables for the TPC-H
/// benchmark.
///
/// The 8 tables are:
/// - customer
/// - lineitem
/// - nation
/// - orders
/// - part
/// - partsupp
/// - region
/// - supplier
///
/// Right now, the data is located in `../data` (a sibling directory of the
/// `eggstrain` crate).
pub async fn tpch_ctx() -> Result<SessionContext> {
let ctx = SessionContext::new();

let tables = [
@@ -27,22 +42,20 @@ pub async fn tpch_dataframe() -> Result<DataFrame> {
.await?;
}

let stdin = io::read_to_string(io::stdin())?;

ctx.sql(&stdin).await
Ok(ctx)
}
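
As a hedged usage sketch (mirroring what `main.rs` does below, except that the SQL text is hard-coded here instead of read from stdin — the query itself is just an example):

use datafusion_common::Result;
use eggstrain::{run, tpch_ctx};

#[tokio::main]
async fn main() -> Result<()> {
    // Register the TPC-H base tables (expects the data files under `../data`).
    let ctx = tpch_ctx().await?;

    // Plan an example query against the registered tables.
    let df = ctx.sql("SELECT c_name FROM customer LIMIT 10").await?;
    let physical_plan = df.create_physical_plan().await?;

    // Hand the DataFusion physical plan to the eggstrain execution engine.
    let batches = run(physical_plan).await;
    println!("received {} record batches", batches.len());
    Ok(())
}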

/// Runs the `eggstrain` execution engine given a DataFusion `ExecutionPlan` (physical plan).
///
/// TODO docs
pub async fn run(plan: Arc<dyn ExecutionPlan>) -> Vec<RecordBatch> {
// Parse the execution plan into a DAG of operators
// where operators are nodes and the edges are broadcasting tokio channels

let mut root_rx = build_query_dag(plan).unwrap();
let mut root_rx = build_execution_dag(plan).unwrap();

// Once we have the DAG, call .await on the top node and hope that
// tokio does its job

let mut all_values = vec![];

while let Ok(x) = root_rx.recv().await {
all_values.push(x);
}
18 changes: 10 additions & 8 deletions eggstrain/src/main.rs
@@ -1,18 +1,20 @@
use arrow::util::pretty;
use datafusion_common::Result;
use eggstrain::{run, tpch_dataframe};
use eggstrain::{run, tpch_ctx};
use std::io;

#[tokio::main]
async fn main() -> Result<()> {
let tpch = tpch_dataframe().await?;
// Create a SessionContext with TPCH base tables
let ctx = tpch_ctx().await?;

let physical_plan = tpch.clone().create_physical_plan().await?;
// Create a DataFrame with the input query
let query = io::read_to_string(io::stdin())?;
let sql = ctx.sql(&query).await?;

println!("{:#?}", physical_plan.clone());

// let physical_plan = physical_plan.children()[0].clone();

let results = run(physical_plan).await;
// Run our execution engine on the physical plan
let df_physical_plan = sql.clone().create_physical_plan().await?;
let results = run(df_physical_plan).await;

results.into_iter().for_each(|batch| {
let pretty_results = pretty::pretty_format_batches(&[batch]).unwrap().to_string();
2 changes: 1 addition & 1 deletion eggstrain/src/scheduler_client/mod.rs
@@ -1 +1 @@

// TODO
2 changes: 1 addition & 1 deletion eggstrain/src/storage_client/mod.rs
@@ -1 +1 @@

// TODO
